1010def test_forward_standardization_training ():
1111 random_input = keras .random .normal ((8 , 4 ))
1212
13- layer = Standardization (momentum = 0.0 ) # no EMA for test stability
13+ layer = Standardization ()
1414 layer .build (random_input .shape )
1515
1616 out = layer (random_input , stage = "training" )
1717
1818 moving_mean = keras .ops .convert_to_numpy (layer .moving_mean [0 ])
19- moving_std = keras .ops .convert_to_numpy (layer .moving_std [0 ])
2019 random_input = keras .ops .convert_to_numpy (random_input )
2120 out = keras .ops .convert_to_numpy (out )
2221
2322 np .testing .assert_allclose (moving_mean , np .mean (random_input , axis = 0 ), atol = 1e-5 )
24- np .testing .assert_allclose (moving_std , np .std (random_input , axis = 0 ), atol = 1e-5 )
2523
2624 assert out .shape == random_input .shape
2725 assert not np .any (np .isnan (out ))
@@ -42,9 +40,10 @@ def test_inverse_standardization_ldj():
4240
4341def test_consistency_forward_inverse ():
4442 random_input = keras .random .normal ((4 , 20 , 5 ))
45- layer = Standardization (momentum = 0.0 )
46- layer .build ((5 ,))
47- standardized = layer (random_input , stage = "training" , forward = True )
43+ layer = Standardization ()
44+ _ = layer (random_input , stage = "training" , forward = True )
45+
46+ standardized = layer (random_input , stage = "inference" , forward = True )
4847 recovered = layer (standardized , stage = "inference" , forward = False )
4948
5049 random_input = keras .ops .convert_to_numpy (random_input )
@@ -58,9 +57,10 @@ def test_nested_consistency_forward_inverse():
5857 random_input_b = keras .random .normal ((4 , 3 ))
5958 random_input = {"a" : random_input_a , "b" : random_input_b }
6059
61- layer = Standardization (momentum = 0.0 )
60+ layer = Standardization ()
6261
63- standardized = layer (random_input , stage = "training" , forward = True )
62+ _ = layer (random_input , stage = "training" , forward = True )
63+ standardized = layer (random_input , stage = "inference" , forward = True )
6464 recovered = layer (standardized , stage = "inference" , forward = False )
6565
6666 random_input = keras .tree .map_structure (keras .ops .convert_to_numpy , random_input )
0 commit comments