Skip to content

Commit fa19ac9

Browse files
feat(layers): Add 3D elastic deformation layer
1 parent cb92954 commit fa19ac9

File tree

3 files changed

+21
-13
lines changed

3 files changed

+21
-13
lines changed

keras_hub/api/layers/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,3 @@
147147
from keras_hub.src.models.xception.xception_image_converter import (
148148
XceptionImageConverter as XceptionImageConverter,
149149
)
150-
151-
from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import (
152-
RandomElasticDeformation3D,
153-
)

keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class RandomElasticDeformation3D(layers.Layer):
88
"""
99
A high-performance 3D elastic deformation layer optimized for TPUs.
1010
"""
11+
1112
def __init__(self,
1213
grid_size=(4, 4, 4),
1314
alpha=35.0,
@@ -20,7 +21,11 @@ def __init__(self,
2021
self.sigma = sigma
2122
self.data_format = data_format
2223
if data_format not in ["channels_last", "channels_first"]:
23-
raise ValueError(f"`data_format` must be one of 'channels_last' or 'channels_first'. Received: {data_format}")
24+
message = (
25+
"`data_format` must be one of 'channels_last' or "
26+
f"'channels_first'. Received: {self.data_format}"
27+
)
28+
raise ValueError(message)
2429

2530
def build(self, input_shape):
2631
self._alpha_tensor = ops.convert_to_tensor(self.alpha, dtype=self.compute_dtype)

keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
# Add keras.utils for the random seed
2-
from keras import utils
32
import os
4-
import numpy as np
3+
54
import keras
6-
from keras import Model
5+
import numpy as np
76
from keras import Input
7+
from keras import Model
88
from keras import ops
9-
from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import RandomElasticDeformation3D
9+
from keras import utils
10+
11+
from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import (
12+
RandomElasticDeformation3D,
13+
)
1014
from keras_hub.src.tests.test_case import TestCase
1115

1216

@@ -44,12 +48,15 @@ def test_serialization(self):
4448
model = Model(inputs=[image_input, label_input], outputs=outputs)
4549
original_output_image, original_output_label = model(input_data)
4650
path = os.path.join(self.get_temp_dir(), "model.keras")
47-
51+
4852
# --- FIX: Remove the deprecated save_format argument ---
4953
model.save(path)
50-
54+
5155
loaded_model = keras.models.load_model(
52-
path, custom_objects={"RandomElasticDeformation3D": RandomElasticDeformation3D}
56+
path,
57+
custom_objects={
58+
"RandomElasticDeformation3D": RandomElasticDeformation3D
59+
},
5360
)
5461
loaded_output_image, loaded_output_label = loaded_model(input_data)
5562
np.testing.assert_allclose(
@@ -71,4 +78,4 @@ def test_label_values_are_preserved(self):
7178
_, output_label = layer((image, label))
7279
output_values = set(np.unique(ops.convert_to_numpy(output_label)))
7380
expected_values = {0, 1, 2, 3}
74-
self.assertLessEqual(output_values, expected_values)
81+
self.assertLessEqual(output_values, expected_values)

0 commit comments

Comments
 (0)