2929@pytest .mark .parametrize ("input_shape" , ["2d" , "3d" ])
3030@pytest .mark .parametrize ("use_soft_flow" , [True , False ])
3131@pytest .mark .parametrize ("permutation" , ["learnable" , "fixed" ])
32- @pytest .mark .parametrize ("coupling_design" , ["affine" , "spline" ])
33- @pytest .mark .parametrize ("num_coupling_layers" , [2 , 8 ])
32+ @pytest .mark .parametrize ("coupling_design" , ["affine" , "spline" , "interleaved" ])
33+ @pytest .mark .parametrize ("num_coupling_layers" , [2 , 7 ])
3434def test_invertible_network (input_shape , use_soft_flow , permutation , coupling_design , num_coupling_layers ):
3535 """Tests the ``InvertibleNetwork`` core class using a couple of relevant configurations."""
3636
3737 # Randomize units and input dim
3838 units = np .random .randint (low = 2 , high = 32 )
3939 input_dim = np .random .randint (low = 2 , high = 32 )
4040
41- # Create settings dictionaries and network
42- coupling_settings = {
43- "dense_args" : dict (units = units , activation = "elu" ),
44- "num_dense" : 1 ,
45- }
41+ # Create settings dictionaries
42+ if coupling_design in ["affine" , "spline" ]:
43+ coupling_settings = {
44+ "dense_args" : dict (units = units , activation = "elu" ),
45+ "num_dense" : 1 ,
46+ }
47+ else :
48+ coupling_settings = {
49+ "affine" : dict (dense_args = {"units" : units , "activation" : "selu" }, num_dense = 1 ),
50+ "spline" : dict (dense_args = {"units" : units , "activation" : "relu" }, bins = 8 , num_dense = 1 ),
51+ }
4652
53+ # Create invertible network with test settings
4754 network = InvertibleNetwork (
4855 num_params = input_dim ,
4956 num_coupling_layers = num_coupling_layers ,
@@ -72,7 +79,7 @@ def test_invertible_network(input_shape, use_soft_flow, permutation, coupling_de
7279 assert network .latent_dim == input_dim
7380 assert len (network .coupling_layers ) == num_coupling_layers
7481 # Test layer attributes
75- for l in network .coupling_layers :
82+ for idx , l in enumerate ( network .coupling_layers ) :
7683 # Permutation
7784 if permutation == "fixed" :
7885 assert isinstance (l .permutation , Permutation )
@@ -85,6 +92,11 @@ def test_invertible_network(input_shape, use_soft_flow, permutation, coupling_de
8592 assert isinstance (l .net1 , AffineCoupling ) and isinstance (l .net2 , AffineCoupling )
8693 elif coupling_design == "spline" :
8794 assert isinstance (l .net1 , SplineCoupling ) and isinstance (l .net2 , SplineCoupling )
95+ elif coupling_design == "interleaved" :
96+ if idx % 2 == 0 :
97+ assert isinstance (l .net1 , AffineCoupling ) and isinstance (l .net2 , AffineCoupling )
98+ else :
99+ assert isinstance (l .net1 , SplineCoupling ) and isinstance (l .net2 , SplineCoupling )
88100
89101 if use_soft_flow :
90102 assert network .soft_flow is True
0 commit comments