Skip to content

Commit cd6e3b7

Browse files
committed
Test that scale_activation changes AffineCouplingTransform results
1 parent b9af0da commit cd6e3b7

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

tests/transforms/coupling_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,18 @@ def test_forward_inverse_are_consistent(self):
7171
with self.subTest(shape=shape):
7272
self.assert_forward_inverse_are_consistent(transform, inputs)
7373

74+
def test_scale_activation_has_an_effect(self):
75+
for shape in self.shapes:
76+
inputs = torch.randn(batch_size, *shape)
77+
transform, mask = create_coupling_transform(
78+
coupling.AffineCouplingTransform, shape
79+
)
80+
outputs_default, logabsdet_default = transform(inputs)
81+
transform.scale_activation = coupling.AffineCouplingTransform.GENERAL_SCALE_ACTIVATION
82+
outputs_general, logabsdet_general = transform(inputs)
83+
with self.subTest(shape=shape):
84+
self.assertNotEqual(outputs_default, outputs_general)
85+
self.assertNotEqual(logabsdet_default, logabsdet_general)
7486

7587
class AdditiveTransformTest(TransformTest):
7688
shapes = [[20], [2, 4, 4]]

0 commit comments

Comments
 (0)