11import pytest
22import torch
3+ from torch .nn import Linear
34
45from pina import LabelTensor
56from pina .model import DeepONet
89data = torch .rand ((20 , 3 ))
910input_vars = ['a' , 'b' , 'c' ]
1011input_ = LabelTensor (data , input_vars )
11-
12+ symbol_funcs_red = DeepONet ._symbol_functions (dim = - 1 )
13+ output_dims = [1 , 5 , 10 , 20 ]
1214
1315def test_constructor ():
1416 branch_net = FeedForward (input_dimensions = 1 , output_dimensions = 10 )
@@ -32,7 +34,6 @@ def test_constructor_fails_when_invalid_inner_layer_size():
3234 reduction = '+' ,
3335 aggregator = '*' )
3436
35-
3637def test_forward_extract_str ():
3738 branch_net = FeedForward (input_dimensions = 1 , output_dimensions = 10 )
3839 trunk_net = FeedForward (input_dimensions = 2 , output_dimensions = 10 )
@@ -43,6 +44,7 @@ def test_forward_extract_str():
4344 reduction = '+' ,
4445 aggregator = '*' )
4546 model (input_ )
47+ assert model (input_ ).shape [- 1 ] == 1
4648
4749
4850def test_forward_extract_int ():
@@ -100,3 +102,30 @@ def test_backward_extract_str_wrong():
100102 l = torch .mean (model (data ))
101103 l .backward ()
102104 assert data ._grad .shape == torch .Size ([20 ,3 ])
105+
106+ @pytest .mark .parametrize ('red' , symbol_funcs_red )
107+ def test_forward_symbol_funcs (red ):
108+ branch_net = FeedForward (input_dimensions = 1 , output_dimensions = 10 )
109+ trunk_net = FeedForward (input_dimensions = 2 , output_dimensions = 10 )
110+ model = DeepONet (branch_net = branch_net ,
111+ trunk_net = trunk_net ,
112+ input_indeces_branch_net = ['a' ],
113+ input_indeces_trunk_net = ['b' , 'c' ],
114+ reduction = red ,
115+ aggregator = '*' )
116+ model (input_ )
117+ assert model (input_ ).shape [- 1 ] == 1
118+
119+ @pytest .mark .parametrize ('out_dim' , output_dims )
120+ def test_forward_callable_reduction (out_dim ):
121+ branch_net = FeedForward (input_dimensions = 1 , output_dimensions = 10 )
122+ trunk_net = FeedForward (input_dimensions = 2 , output_dimensions = 10 )
123+ reduction_layer = Linear (10 , out_dim )
124+ model = DeepONet (branch_net = branch_net ,
125+ trunk_net = trunk_net ,
126+ input_indeces_branch_net = ['a' ],
127+ input_indeces_trunk_net = ['b' , 'c' ],
128+ reduction = reduction_layer ,
129+ aggregator = '*' )
130+ model (input_ )
131+ assert model (input_ ).shape [- 1 ] == out_dim
0 commit comments