1
- # Add keras.utils for the random seed
1
+
2
2
import os
3
3
4
4
import keras
16
16
17
17
class RandomElasticDeformation3DTest (TestCase ):
18
18
def test_layer_basics (self ):
19
- # --- BEST PRACTICE: Add a seed for reproducibility ---
19
+
20
20
utils .set_random_seed (0 )
21
21
layer = RandomElasticDeformation3D (
22
22
grid_size = (4 , 4 , 4 ),
@@ -32,8 +32,6 @@ def test_layer_basics(self):
32
32
self .assertEqual (label .dtype , output_label .dtype )
33
33
34
34
def test_serialization (self ):
35
- # --- BEST PRACTICE: Add a seed for reproducibility ---
36
- utils .set_random_seed (0 )
37
35
layer = RandomElasticDeformation3D (
38
36
grid_size = (3 , 3 , 3 ),
39
37
alpha = 50.0 ,
@@ -42,34 +40,39 @@ def test_serialization(self):
42
40
image_data = ops .ones ((2 , 16 , 16 , 16 , 3 ), dtype = "float32" )
43
41
label_data = ops .ones ((2 , 16 , 16 , 16 , 1 ), dtype = "int32" )
44
42
input_data = (image_data , label_data )
43
+
45
44
image_input = Input (shape = (16 , 16 , 16 , 3 ), dtype = "float32" )
46
45
label_input = Input (shape = (16 , 16 , 16 , 1 ), dtype = "int32" )
47
46
outputs = layer ((image_input , label_input ))
48
47
model = Model (inputs = [image_input , label_input ], outputs = outputs )
48
+
49
+
50
+ utils .set_random_seed (0 )
49
51
original_output_image , original_output_label = model (input_data )
50
- path = os .path .join (self .get_temp_dir (), "model.keras" )
51
52
52
- # --- FIX: Remove the deprecated save_format argument ---
53
+ path = os . path . join ( self . get_temp_dir (), "model.keras" )
53
54
model .save (path )
54
-
55
55
loaded_model = keras .models .load_model (
56
- path ,
57
- custom_objects = {
58
- "RandomElasticDeformation3D" : RandomElasticDeformation3D
59
- },
56
+ path , custom_objects = {"RandomElasticDeformation3D" : RandomElasticDeformation3D }
60
57
)
58
+
59
+
60
+ utils .set_random_seed (0 )
61
61
loaded_output_image , loaded_output_label = loaded_model (input_data )
62
+
63
+
62
64
np .testing .assert_allclose (
63
65
ops .convert_to_numpy (original_output_image ),
64
66
ops .convert_to_numpy (loaded_output_image ),
67
+ atol = 1e-6
65
68
)
66
69
np .testing .assert_array_equal (
67
70
ops .convert_to_numpy (original_output_label ),
68
71
ops .convert_to_numpy (loaded_output_label ),
69
72
)
70
73
71
74
def test_label_values_are_preserved (self ):
72
- # --- BEST PRACTICE: Add a seed for reproducibility ---
75
+
73
76
utils .set_random_seed (0 )
74
77
image = ops .zeros (shape = (1 , 16 , 16 , 16 , 1 ), dtype = "float32" )
75
78
label_arange = ops .arange (16 ** 3 , dtype = "int32" )
0 commit comments