Skip to content

Commit 88836c8

Browse files
feat(layers): Add 3D elastic deformation layer
1 parent fa19ac9 commit 88836c8

File tree

2 files changed

+130
-43
lines changed

2 files changed

+130
-43
lines changed
Lines changed: 122 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
# Add this import
22
from keras import backend
3-
from keras import ops
43
from keras import layers
4+
from keras import ops
55
from keras import random
66

7+
78
class RandomElasticDeformation3D(layers.Layer):
89
"""
910
A high-performance 3D elastic deformation layer optimized for TPUs.
1011
"""
1112

12-
def __init__(self,
13-
grid_size=(4, 4, 4),
14-
alpha=35.0,
15-
sigma=2.5,
16-
data_format="channels_last",
17-
**kwargs):
13+
def __init__(
14+
self,
15+
grid_size=(4, 4, 4),
16+
alpha=35.0,
17+
sigma=2.5,
18+
data_format="channels_last",
19+
**kwargs,
20+
):
1821
super().__init__(**kwargs)
1922
self.grid_size = grid_size
2023
self.alpha = alpha
@@ -28,21 +31,36 @@ def __init__(self,
2831
raise ValueError(message)
2932

3033
def build(self, input_shape):
31-
self._alpha_tensor = ops.convert_to_tensor(self.alpha, dtype=self.compute_dtype)
32-
self._sigma_tensor = ops.convert_to_tensor(self.sigma, dtype=self.compute_dtype)
33-
kernel_size = ops.cast(2 * ops.round(3 * self._sigma_tensor) + 1, dtype="int32")
34-
ax = ops.arange(-ops.cast(kernel_size // 2, self.compute_dtype) + 1.0, ops.cast(kernel_size // 2, self.compute_dtype) + 1.0)
34+
self._alpha_tensor = ops.convert_to_tensor(
35+
self.alpha, dtype=self.compute_dtype
36+
)
37+
self._sigma_tensor = ops.convert_to_tensor(
38+
self.sigma, dtype=self.compute_dtype
39+
)
40+
kernel_size = ops.cast(
41+
2 * ops.round(3 * self._sigma_tensor) + 1, dtype="int32"
42+
)
43+
ax = ops.arange(
44+
-ops.cast(kernel_size // 2, self.compute_dtype) + 1.0,
45+
ops.cast(kernel_size // 2, self.compute_dtype) + 1.0,
46+
)
3547
kernel_1d = ops.exp(-(ax**2) / (2.0 * self._sigma_tensor**2))
3648
self.kernel_1d = kernel_1d / ops.sum(kernel_1d)
3749
self.built = True
3850

3951
def _separable_gaussian_filter_3d(self, tensor):
4052
depth_kernel = ops.reshape(self.kernel_1d, (-1, 1, 1, 1, 1))
41-
tensor = ops.conv(tensor, ops.cast(depth_kernel, dtype=tensor.dtype), padding='same')
53+
tensor = ops.conv(
54+
tensor, ops.cast(depth_kernel, dtype=tensor.dtype), padding="same"
55+
)
4256
height_kernel = ops.reshape(self.kernel_1d, (1, -1, 1, 1, 1))
43-
tensor = ops.conv(tensor, ops.cast(height_kernel, dtype=tensor.dtype), padding='same')
57+
tensor = ops.conv(
58+
tensor, ops.cast(height_kernel, dtype=tensor.dtype), padding="same"
59+
)
4460
width_kernel = ops.reshape(self.kernel_1d, (1, 1, -1, 1, 1))
45-
tensor = ops.conv(tensor, ops.cast(width_kernel, dtype=tensor.dtype), padding='same')
61+
tensor = ops.conv(
62+
tensor, ops.cast(width_kernel, dtype=tensor.dtype), padding="same"
63+
)
4664
return tensor
4765

4866
def call(self, inputs):
@@ -61,33 +79,75 @@ def call(self, inputs):
6179
label_volume = ops.cast(label_volume, dtype=compute_dtype)
6280

6381
input_shape = ops.shape(image_volume)
64-
B, D, H, W, C = input_shape[0], input_shape[1], input_shape[2], input_shape[3], input_shape[4]
65-
66-
coarse_flow = random.uniform(shape=(B, self.grid_size[0], self.grid_size[1], self.grid_size[2], 3), minval=-1, maxval=1, dtype=compute_dtype)
67-
82+
B, D, H, W, C = (
83+
input_shape[0],
84+
input_shape[1],
85+
input_shape[2],
86+
input_shape[3],
87+
input_shape[4],
88+
)
89+
90+
coarse_flow = random.uniform(
91+
shape=(
92+
B,
93+
self.grid_size[0],
94+
self.grid_size[1],
95+
self.grid_size[2],
96+
3,
97+
),
98+
minval=-1,
99+
maxval=1,
100+
dtype=compute_dtype,
101+
)
102+
68103
flow = coarse_flow
69104
flow_shape = ops.shape(flow)
70-
flow = ops.reshape(flow, (flow_shape[0] * flow_shape[1], flow_shape[2], flow_shape[3], 3))
105+
flow = ops.reshape(
106+
flow,
107+
(flow_shape[0] * flow_shape[1], flow_shape[2], flow_shape[3], 3),
108+
)
71109
flow = ops.image.resize(flow, (H, W), interpolation="bicubic")
72110
flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], H, W, 3))
73111
flow = ops.transpose(flow, (0, 2, 3, 1, 4))
74112
flow_shape = ops.shape(flow)
75-
flow = ops.reshape(flow, (flow_shape[0] * flow_shape[1] * flow_shape[2], flow_shape[3], 1, 3))
113+
flow = ops.reshape(
114+
flow,
115+
(
116+
flow_shape[0] * flow_shape[1] * flow_shape[2],
117+
flow_shape[3],
118+
1,
119+
3,
120+
),
121+
)
76122
flow = ops.image.resize(flow, (D, 1), interpolation="bicubic")
77-
flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], flow_shape[2], D, 3))
123+
flow = ops.reshape(
124+
flow, (flow_shape[0], flow_shape[1], flow_shape[2], D, 3)
125+
)
78126
flow = ops.transpose(flow, (0, 3, 1, 2, 4))
79-
127+
80128
flow_components = ops.unstack(flow, axis=-1)
81129
smoothed_components = []
82130
for component in flow_components:
83-
smoothed_components.append(ops.squeeze(self._separable_gaussian_filter_3d(ops.expand_dims(component, axis=-1)), axis=-1))
131+
smoothed_components.append(
132+
ops.squeeze(
133+
self._separable_gaussian_filter_3d(
134+
ops.expand_dims(component, axis=-1)
135+
),
136+
axis=-1,
137+
)
138+
)
84139
smoothed_flow = ops.stack(smoothed_components, axis=-1)
85-
140+
86141
flow = smoothed_flow * self._alpha_tensor
87-
grid_d, grid_h, grid_w = ops.meshgrid(ops.arange(D, dtype=compute_dtype), ops.arange(H, dtype=compute_dtype), ops.arange(W, dtype=compute_dtype), indexing='ij')
142+
grid_d, grid_h, grid_w = ops.meshgrid(
143+
ops.arange(D, dtype=compute_dtype),
144+
ops.arange(H, dtype=compute_dtype),
145+
ops.arange(W, dtype=compute_dtype),
146+
indexing="ij",
147+
)
88148
grid = ops.stack([grid_d, grid_h, grid_w], axis=-1)
89149
warp_grid = ops.expand_dims(grid, 0) + flow
90-
150+
91151
batched_coords = ops.transpose(warp_grid, (0, 4, 1, 2, 3))
92152

93153
def perform_map(elems):
@@ -96,25 +156,45 @@ def perform_map(elems):
96156
image_slice_transposed = ops.transpose(image_slice, (3, 0, 1, 2))
97157
# The channel dimension C is a static value when the graph is built
98158
for c in range(C):
99-
deformed_channels.append(ops.image.map_coordinates(image_slice_transposed[c], coords, order=1))
159+
deformed_channels.append(
160+
ops.image.map_coordinates(
161+
image_slice_transposed[c], coords, order=1
162+
)
163+
)
100164
deformed_image_slice = ops.stack(deformed_channels, axis=0)
101-
deformed_image_slice = ops.transpose(deformed_image_slice, (1, 2, 3, 0))
165+
deformed_image_slice = ops.transpose(
166+
deformed_image_slice, (1, 2, 3, 0)
167+
)
102168
label_channel = ops.squeeze(label_slice, axis=-1)
103-
deformed_label_channel = ops.image.map_coordinates(label_channel, coords, order=0)
104-
deformed_label_slice = ops.expand_dims(deformed_label_channel, axis=-1)
169+
deformed_label_channel = ops.image.map_coordinates(
170+
label_channel, coords, order=0
171+
)
172+
deformed_label_slice = ops.expand_dims(
173+
deformed_label_channel, axis=-1
174+
)
105175
return deformed_image_slice, deformed_label_slice
106176

107177
if backend.backend() == "tensorflow":
108178
import tensorflow as tf
109-
deformed_image, deformed_label = tf.map_fn(perform_map, elems=(image_volume, label_volume, batched_coords), dtype=(compute_dtype, compute_dtype))
179+
180+
deformed_image, deformed_label = tf.map_fn(
181+
perform_map,
182+
elems=(image_volume, label_volume, batched_coords),
183+
dtype=(compute_dtype, compute_dtype),
184+
)
110185
elif backend.backend() == "jax":
111186
import jax
112-
deformed_image, deformed_label = jax.lax.map(perform_map, xs=(image_volume, label_volume, batched_coords))
187+
188+
deformed_image, deformed_label = jax.lax.map(
189+
perform_map, xs=(image_volume, label_volume, batched_coords)
190+
)
113191
else:
114192
deformed_images_list = []
115193
deformed_labels_list = []
116194
for i in range(B):
117-
img_slice, lbl_slice = perform_map((image_volume[i], label_volume[i], batched_coords[i]))
195+
img_slice, lbl_slice = perform_map(
196+
(image_volume[i], label_volume[i], batched_coords[i])
197+
)
118198
deformed_images_list.append(img_slice)
119199
deformed_labels_list.append(lbl_slice)
120200
deformed_image = ops.stack(deformed_images_list, axis=0)
@@ -135,5 +215,12 @@ def compute_output_shape(self, input_shape):
135215

136216
def get_config(self):
137217
config = super().get_config()
138-
config.update({"grid_size": self.grid_size, "alpha": self.alpha, "sigma": self.sigma, "data_format": self.data_format})
139-
return config
218+
config.update(
219+
{
220+
"grid_size": self.grid_size,
221+
"alpha": self.alpha,
222+
"sigma": self.sigma,
223+
"data_format": self.data_format,
224+
}
225+
)
226+
return config

keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# Add keras.utils for the random seed
21
import os
32

43
import keras
@@ -16,7 +15,6 @@
1615

1716
class RandomElasticDeformation3DTest(TestCase):
1817
def test_layer_basics(self):
19-
# --- BEST PRACTICE: Add a seed for reproducibility ---
2018
utils.set_random_seed(0)
2119
layer = RandomElasticDeformation3D(
2220
grid_size=(4, 4, 4),
@@ -32,8 +30,6 @@ def test_layer_basics(self):
3230
self.assertEqual(label.dtype, output_label.dtype)
3331

3432
def test_serialization(self):
35-
# --- BEST PRACTICE: Add a seed for reproducibility ---
36-
utils.set_random_seed(0)
3733
layer = RandomElasticDeformation3D(
3834
grid_size=(3, 3, 3),
3935
alpha=50.0,
@@ -42,34 +38,38 @@ def test_serialization(self):
4238
image_data = ops.ones((2, 16, 16, 16, 3), dtype="float32")
4339
label_data = ops.ones((2, 16, 16, 16, 1), dtype="int32")
4440
input_data = (image_data, label_data)
41+
4542
image_input = Input(shape=(16, 16, 16, 3), dtype="float32")
4643
label_input = Input(shape=(16, 16, 16, 1), dtype="int32")
4744
outputs = layer((image_input, label_input))
4845
model = Model(inputs=[image_input, label_input], outputs=outputs)
46+
47+
utils.set_random_seed(0)
4948
original_output_image, original_output_label = model(input_data)
50-
path = os.path.join(self.get_temp_dir(), "model.keras")
5149

52-
# --- FIX: Remove the deprecated save_format argument ---
50+
path = os.path.join(self.get_temp_dir(), "model.keras")
5351
model.save(path)
54-
5552
loaded_model = keras.models.load_model(
5653
path,
5754
custom_objects={
5855
"RandomElasticDeformation3D": RandomElasticDeformation3D
5956
},
6057
)
58+
59+
utils.set_random_seed(0)
6160
loaded_output_image, loaded_output_label = loaded_model(input_data)
61+
6262
np.testing.assert_allclose(
6363
ops.convert_to_numpy(original_output_image),
6464
ops.convert_to_numpy(loaded_output_image),
65+
atol=1e-6,
6566
)
6667
np.testing.assert_array_equal(
6768
ops.convert_to_numpy(original_output_label),
6869
ops.convert_to_numpy(loaded_output_label),
6970
)
7071

7172
def test_label_values_are_preserved(self):
72-
# --- BEST PRACTICE: Add a seed for reproducibility ---
7373
utils.set_random_seed(0)
7474
image = ops.zeros(shape=(1, 16, 16, 16, 1), dtype="float32")
7575
label_arange = ops.arange(16**3, dtype="int32")

0 commit comments

Comments
 (0)