Skip to content

Commit 4cf6ee0

Browse files
feat(layers): Add 3D elastic deformation layer
1 parent fa19ac9 commit 4cf6ee0

File tree

2 files changed

+149
-46
lines changed

2 files changed

+149
-46
lines changed
Lines changed: 141 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,31 @@
11
# Add this import
22
from keras import backend
3-
from keras import ops
43
from keras import layers
4+
from keras import ops
55
from keras import random
66

7+
78
class RandomElasticDeformation3D(layers.Layer):
89
"""
910
A high-performance 3D elastic deformation layer optimized for TPUs.
1011
"""
1112

12-
def __init__(self,
13-
grid_size=(4, 4, 4),
14-
alpha=35.0,
15-
sigma=2.5,
16-
data_format="channels_last",
17-
**kwargs):
13+
def __init__(
14+
self,
15+
grid_size=(4, 4, 4),
16+
alpha=35.0,
17+
sigma=2.5,
18+
data_format="channels_last",
19+
seed=None,
20+
**kwargs,
21+
):
1822
super().__init__(**kwargs)
1923
self.grid_size = grid_size
24+
self.seed = seed
2025
self.alpha = alpha
2126
self.sigma = sigma
2227
self.data_format = data_format
28+
self._rng = random.SeedGenerator(seed) if seed is not None else None
2329
if data_format not in ["channels_last", "channels_first"]:
2430
message = (
2531
"`data_format` must be one of 'channels_last' or "
@@ -28,21 +34,36 @@ def __init__(self,
2834
raise ValueError(message)
2935

3036
def build(self, input_shape):
31-
self._alpha_tensor = ops.convert_to_tensor(self.alpha, dtype=self.compute_dtype)
32-
self._sigma_tensor = ops.convert_to_tensor(self.sigma, dtype=self.compute_dtype)
33-
kernel_size = ops.cast(2 * ops.round(3 * self._sigma_tensor) + 1, dtype="int32")
34-
ax = ops.arange(-ops.cast(kernel_size // 2, self.compute_dtype) + 1.0, ops.cast(kernel_size // 2, self.compute_dtype) + 1.0)
37+
self._alpha_tensor = ops.convert_to_tensor(
38+
self.alpha, dtype=self.compute_dtype
39+
)
40+
self._sigma_tensor = ops.convert_to_tensor(
41+
self.sigma, dtype=self.compute_dtype
42+
)
43+
kernel_size = ops.cast(
44+
2 * ops.round(3 * self._sigma_tensor) + 1, dtype="int32"
45+
)
46+
ax = ops.arange(
47+
-ops.cast(kernel_size // 2, self.compute_dtype) + 1.0,
48+
ops.cast(kernel_size // 2, self.compute_dtype) + 1.0,
49+
)
3550
kernel_1d = ops.exp(-(ax**2) / (2.0 * self._sigma_tensor**2))
3651
self.kernel_1d = kernel_1d / ops.sum(kernel_1d)
3752
self.built = True
3853

3954
def _separable_gaussian_filter_3d(self, tensor):
4055
depth_kernel = ops.reshape(self.kernel_1d, (-1, 1, 1, 1, 1))
41-
tensor = ops.conv(tensor, ops.cast(depth_kernel, dtype=tensor.dtype), padding='same')
56+
tensor = ops.conv(
57+
tensor, ops.cast(depth_kernel, dtype=tensor.dtype), padding="same"
58+
)
4259
height_kernel = ops.reshape(self.kernel_1d, (1, -1, 1, 1, 1))
43-
tensor = ops.conv(tensor, ops.cast(height_kernel, dtype=tensor.dtype), padding='same')
60+
tensor = ops.conv(
61+
tensor, ops.cast(height_kernel, dtype=tensor.dtype), padding="same"
62+
)
4463
width_kernel = ops.reshape(self.kernel_1d, (1, 1, -1, 1, 1))
45-
tensor = ops.conv(tensor, ops.cast(width_kernel, dtype=tensor.dtype), padding='same')
64+
tensor = ops.conv(
65+
tensor, ops.cast(width_kernel, dtype=tensor.dtype), padding="same"
66+
)
4667
return tensor
4768

4869
def call(self, inputs):
@@ -61,33 +82,90 @@ def call(self, inputs):
6182
label_volume = ops.cast(label_volume, dtype=compute_dtype)
6283

6384
input_shape = ops.shape(image_volume)
64-
B, D, H, W, C = input_shape[0], input_shape[1], input_shape[2], input_shape[3], input_shape[4]
65-
66-
coarse_flow = random.uniform(shape=(B, self.grid_size[0], self.grid_size[1], self.grid_size[2], 3), minval=-1, maxval=1, dtype=compute_dtype)
67-
85+
B, D, H, W, C = (
86+
input_shape[0],
87+
input_shape[1],
88+
input_shape[2],
89+
input_shape[3],
90+
input_shape[4],
91+
)
92+
93+
if self._rng is not None:
94+
coarse_flow = random.uniform(
95+
shape=(
96+
B,
97+
self.grid_size[0],
98+
self.grid_size[1],
99+
self.grid_size[2],
100+
3,
101+
),
102+
minval=-1,
103+
maxval=1,
104+
dtype=compute_dtype,
105+
seed=self._rng,
106+
)
107+
else:
108+
coarse_flow = random.uniform(
109+
shape=(
110+
B,
111+
self.grid_size[0],
112+
self.grid_size[1],
113+
self.grid_size[2],
114+
3,
115+
),
116+
minval=-1,
117+
maxval=1,
118+
dtype=compute_dtype,
119+
)
120+
68121
flow = coarse_flow
69122
flow_shape = ops.shape(flow)
70-
flow = ops.reshape(flow, (flow_shape[0] * flow_shape[1], flow_shape[2], flow_shape[3], 3))
123+
flow = ops.reshape(
124+
flow,
125+
(flow_shape[0] * flow_shape[1], flow_shape[2], flow_shape[3], 3),
126+
)
71127
flow = ops.image.resize(flow, (H, W), interpolation="bicubic")
72128
flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], H, W, 3))
73129
flow = ops.transpose(flow, (0, 2, 3, 1, 4))
74130
flow_shape = ops.shape(flow)
75-
flow = ops.reshape(flow, (flow_shape[0] * flow_shape[1] * flow_shape[2], flow_shape[3], 1, 3))
131+
flow = ops.reshape(
132+
flow,
133+
(
134+
flow_shape[0] * flow_shape[1] * flow_shape[2],
135+
flow_shape[3],
136+
1,
137+
3,
138+
),
139+
)
76140
flow = ops.image.resize(flow, (D, 1), interpolation="bicubic")
77-
flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], flow_shape[2], D, 3))
141+
flow = ops.reshape(
142+
flow, (flow_shape[0], flow_shape[1], flow_shape[2], D, 3)
143+
)
78144
flow = ops.transpose(flow, (0, 3, 1, 2, 4))
79-
145+
80146
flow_components = ops.unstack(flow, axis=-1)
81147
smoothed_components = []
82148
for component in flow_components:
83-
smoothed_components.append(ops.squeeze(self._separable_gaussian_filter_3d(ops.expand_dims(component, axis=-1)), axis=-1))
149+
smoothed_components.append(
150+
ops.squeeze(
151+
self._separable_gaussian_filter_3d(
152+
ops.expand_dims(component, axis=-1)
153+
),
154+
axis=-1,
155+
)
156+
)
84157
smoothed_flow = ops.stack(smoothed_components, axis=-1)
85-
158+
86159
flow = smoothed_flow * self._alpha_tensor
87-
grid_d, grid_h, grid_w = ops.meshgrid(ops.arange(D, dtype=compute_dtype), ops.arange(H, dtype=compute_dtype), ops.arange(W, dtype=compute_dtype), indexing='ij')
160+
grid_d, grid_h, grid_w = ops.meshgrid(
161+
ops.arange(D, dtype=compute_dtype),
162+
ops.arange(H, dtype=compute_dtype),
163+
ops.arange(W, dtype=compute_dtype),
164+
indexing="ij",
165+
)
88166
grid = ops.stack([grid_d, grid_h, grid_w], axis=-1)
89167
warp_grid = ops.expand_dims(grid, 0) + flow
90-
168+
91169
batched_coords = ops.transpose(warp_grid, (0, 4, 1, 2, 3))
92170

93171
def perform_map(elems):
@@ -96,25 +174,45 @@ def perform_map(elems):
96174
image_slice_transposed = ops.transpose(image_slice, (3, 0, 1, 2))
97175
# The channel dimension C is a static value when the graph is built
98176
for c in range(C):
99-
deformed_channels.append(ops.image.map_coordinates(image_slice_transposed[c], coords, order=1))
177+
deformed_channels.append(
178+
ops.image.map_coordinates(
179+
image_slice_transposed[c], coords, order=1
180+
)
181+
)
100182
deformed_image_slice = ops.stack(deformed_channels, axis=0)
101-
deformed_image_slice = ops.transpose(deformed_image_slice, (1, 2, 3, 0))
183+
deformed_image_slice = ops.transpose(
184+
deformed_image_slice, (1, 2, 3, 0)
185+
)
102186
label_channel = ops.squeeze(label_slice, axis=-1)
103-
deformed_label_channel = ops.image.map_coordinates(label_channel, coords, order=0)
104-
deformed_label_slice = ops.expand_dims(deformed_label_channel, axis=-1)
187+
deformed_label_channel = ops.image.map_coordinates(
188+
label_channel, coords, order=0
189+
)
190+
deformed_label_slice = ops.expand_dims(
191+
deformed_label_channel, axis=-1
192+
)
105193
return deformed_image_slice, deformed_label_slice
106194

107195
if backend.backend() == "tensorflow":
108196
import tensorflow as tf
109-
deformed_image, deformed_label = tf.map_fn(perform_map, elems=(image_volume, label_volume, batched_coords), dtype=(compute_dtype, compute_dtype))
197+
198+
deformed_image, deformed_label = tf.map_fn(
199+
perform_map,
200+
elems=(image_volume, label_volume, batched_coords),
201+
dtype=(compute_dtype, compute_dtype),
202+
)
110203
elif backend.backend() == "jax":
111204
import jax
112-
deformed_image, deformed_label = jax.lax.map(perform_map, xs=(image_volume, label_volume, batched_coords))
205+
206+
deformed_image, deformed_label = jax.lax.map(
207+
perform_map, xs=(image_volume, label_volume, batched_coords)
208+
)
113209
else:
114210
deformed_images_list = []
115211
deformed_labels_list = []
116212
for i in range(B):
117-
img_slice, lbl_slice = perform_map((image_volume[i], label_volume[i], batched_coords[i]))
213+
img_slice, lbl_slice = perform_map(
214+
(image_volume[i], label_volume[i], batched_coords[i])
215+
)
118216
deformed_images_list.append(img_slice)
119217
deformed_labels_list.append(lbl_slice)
120218
deformed_image = ops.stack(deformed_images_list, axis=0)
@@ -135,5 +233,13 @@ def compute_output_shape(self, input_shape):
135233

136234
def get_config(self):
137235
config = super().get_config()
138-
config.update({"grid_size": self.grid_size, "alpha": self.alpha, "sigma": self.sigma, "data_format": self.data_format})
139-
return config
236+
config.update(
237+
{
238+
"grid_size": self.grid_size,
239+
"alpha": self.alpha,
240+
"sigma": self.sigma,
241+
"data_format": self.data_format,
242+
"seed": self.seed,
243+
}
244+
)
245+
return config

keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# Add keras.utils for the random seed
21
import os
32

43
import keras
@@ -16,12 +15,9 @@
1615

1716
class RandomElasticDeformation3DTest(TestCase):
1817
def test_layer_basics(self):
19-
# --- BEST PRACTICE: Add a seed for reproducibility ---
2018
utils.set_random_seed(0)
2119
layer = RandomElasticDeformation3D(
22-
grid_size=(4, 4, 4),
23-
alpha=10.0,
24-
sigma=2.0,
20+
grid_size=(4, 4, 4), alpha=10.0, sigma=2.0, seed=0
2521
)
2622
image = ops.ones((2, 32, 32, 32, 3), dtype="float32")
2723
label = ops.ones((2, 32, 32, 32, 1), dtype="int32")
@@ -32,44 +28,45 @@ def test_layer_basics(self):
3228
self.assertEqual(label.dtype, output_label.dtype)
3329

3430
def test_serialization(self):
35-
# --- BEST PRACTICE: Add a seed for reproducibility ---
36-
utils.set_random_seed(0)
3731
layer = RandomElasticDeformation3D(
3832
grid_size=(3, 3, 3),
3933
alpha=50.0,
4034
sigma=5.0,
35+
seed=0,
4136
)
4237
image_data = ops.ones((2, 16, 16, 16, 3), dtype="float32")
4338
label_data = ops.ones((2, 16, 16, 16, 1), dtype="int32")
4439
input_data = (image_data, label_data)
40+
4541
image_input = Input(shape=(16, 16, 16, 3), dtype="float32")
4642
label_input = Input(shape=(16, 16, 16, 1), dtype="int32")
4743
outputs = layer((image_input, label_input))
4844
model = Model(inputs=[image_input, label_input], outputs=outputs)
45+
4946
original_output_image, original_output_label = model(input_data)
50-
path = os.path.join(self.get_temp_dir(), "model.keras")
5147

52-
# --- FIX: Remove the deprecated save_format argument ---
48+
path = os.path.join(self.get_temp_dir(), "model.keras")
5349
model.save(path)
54-
5550
loaded_model = keras.models.load_model(
5651
path,
5752
custom_objects={
5853
"RandomElasticDeformation3D": RandomElasticDeformation3D
5954
},
6055
)
56+
6157
loaded_output_image, loaded_output_label = loaded_model(input_data)
58+
6259
np.testing.assert_allclose(
6360
ops.convert_to_numpy(original_output_image),
6461
ops.convert_to_numpy(loaded_output_image),
62+
atol=1e-6,
6563
)
6664
np.testing.assert_array_equal(
6765
ops.convert_to_numpy(original_output_label),
6866
ops.convert_to_numpy(loaded_output_label),
6967
)
7068

7169
def test_label_values_are_preserved(self):
72-
# --- BEST PRACTICE: Add a seed for reproducibility ---
7370
utils.set_random_seed(0)
7471
image = ops.zeros(shape=(1, 16, 16, 16, 1), dtype="float32")
7572
label_arange = ops.arange(16**3, dtype="int32")

0 commit comments

Comments
 (0)