Skip to content

Commit dc925ce

Browse files
feat(layers): Add 3D elastic deformation layer
1 parent fa19ac9 commit dc925ce

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(self,
2121
self.sigma = sigma
2222
self.data_format = data_format
2323
if data_format not in ["channels_last", "channels_first"]:
24+
2425
message = (
2526
"`data_format` must be one of 'channels_last' or "
2627
f"'channels_first'. Received: {self.data_format}"

keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Add keras.utils for the random seed
1+
22
import os
33

44
import keras
@@ -16,7 +16,7 @@
1616

1717
class RandomElasticDeformation3DTest(TestCase):
1818
def test_layer_basics(self):
19-
# --- BEST PRACTICE: Add a seed for reproducibility ---
19+
2020
utils.set_random_seed(0)
2121
layer = RandomElasticDeformation3D(
2222
grid_size=(4, 4, 4),
@@ -32,8 +32,6 @@ def test_layer_basics(self):
3232
self.assertEqual(label.dtype, output_label.dtype)
3333

3434
def test_serialization(self):
35-
# --- BEST PRACTICE: Add a seed for reproducibility ---
36-
utils.set_random_seed(0)
3735
layer = RandomElasticDeformation3D(
3836
grid_size=(3, 3, 3),
3937
alpha=50.0,
@@ -42,34 +40,39 @@ def test_serialization(self):
4240
image_data = ops.ones((2, 16, 16, 16, 3), dtype="float32")
4341
label_data = ops.ones((2, 16, 16, 16, 1), dtype="int32")
4442
input_data = (image_data, label_data)
43+
4544
image_input = Input(shape=(16, 16, 16, 3), dtype="float32")
4645
label_input = Input(shape=(16, 16, 16, 1), dtype="int32")
4746
outputs = layer((image_input, label_input))
4847
model = Model(inputs=[image_input, label_input], outputs=outputs)
48+
49+
50+
utils.set_random_seed(0)
4951
original_output_image, original_output_label = model(input_data)
50-
path = os.path.join(self.get_temp_dir(), "model.keras")
5152

52-
# --- FIX: Remove the deprecated save_format argument ---
53+
path = os.path.join(self.get_temp_dir(), "model.keras")
5354
model.save(path)
54-
5555
loaded_model = keras.models.load_model(
56-
path,
57-
custom_objects={
58-
"RandomElasticDeformation3D": RandomElasticDeformation3D
59-
},
56+
path, custom_objects={"RandomElasticDeformation3D": RandomElasticDeformation3D}
6057
)
58+
59+
60+
utils.set_random_seed(0)
6161
loaded_output_image, loaded_output_label = loaded_model(input_data)
62+
63+
6264
np.testing.assert_allclose(
6365
ops.convert_to_numpy(original_output_image),
6466
ops.convert_to_numpy(loaded_output_image),
67+
atol=1e-6
6568
)
6669
np.testing.assert_array_equal(
6770
ops.convert_to_numpy(original_output_label),
6871
ops.convert_to_numpy(loaded_output_label),
6972
)
7073

7174
def test_label_values_are_preserved(self):
72-
# --- BEST PRACTICE: Add a seed for reproducibility ---
75+
7376
utils.set_random_seed(0)
7477
image = ops.zeros(shape=(1, 16, 16, 16, 1), dtype="float32")
7578
label_arange = ops.arange(16**3, dtype="int32")

0 commit comments

Comments
 (0)