File tree Expand file tree Collapse file tree 1 file changed +12
-0
lines changed
Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Original file line number Diff line number Diff line change @@ -71,6 +71,18 @@ def test_forward_inverse_are_consistent(self):
7171 with self .subTest (shape = shape ):
7272 self .assert_forward_inverse_are_consistent (transform , inputs )
7373
74+ def test_scale_activation_has_an_effect (self ):
75+ for shape in self .shapes :
76+ inputs = torch .randn (batch_size , * shape )
77+ transform , mask = create_coupling_transform (
78+ coupling .AffineCouplingTransform , shape
79+ )
80+ outputs_default , logabsdet_default = transform (inputs )
81+ transform .scale_activation = coupling .AffineCouplingTransform .GENERAL_SCALE_ACTIVATION
82+ outputs_general , logabsdet_general = transform (inputs )
83+ with self .subTest (shape = shape ):
84+ self .assertNotEqual (outputs_default , outputs_general )
85+ self .assertNotEqual (logabsdet_default , logabsdet_general )
7486
7587class AdditiveTransformTest (TransformTest ):
7688 shapes = [[20 ], [2 , 4 , 4 ]]
You can’t perform that action at this time.
0 commit comments