Skip to content

Commit cb92954

Browse files
feat(layers): updated 3D elastic deformation layer
1 parent 9793ab9 commit cb92954

File tree

2 files changed

+50
-121
lines changed

2 files changed

+50
-121
lines changed
Lines changed: 38 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1+
# Add this import
2+
from keras import backend
13
from keras import ops
24
from keras import layers
35
from keras import random
46

57
class RandomElasticDeformation3D(layers.Layer):
68
"""
79
A high-performance 3D elastic deformation layer optimized for TPUs.
8-
9-
This implementation leverages the layer's compute_dtype (e.g., bfloat16)
10-
to potentially halve memory bandwidth requirements and uses a vectorized
11-
mapping for maximum speed.
1210
"""
1311
def __init__(self,
1412
grid_size=(4, 4, 4),
@@ -17,41 +15,29 @@ def __init__(self,
1715
data_format="channels_last",
1816
**kwargs):
1917
super().__init__(**kwargs)
20-
2118
self.grid_size = grid_size
2219
self.alpha = alpha
2320
self.sigma = sigma
2421
self.data_format = data_format
2522
if data_format not in ["channels_last", "channels_first"]:
26-
raise ValueError(
27-
"`data_format` must be one of 'channels_last' or "
28-
f"'channels_first'. Received: {data_format}"
29-
)
30-
23+
raise ValueError(f"`data_format` must be one of 'channels_last' or 'channels_first'. Received: {data_format}")
24+
3125
def build(self, input_shape):
32-
"""Create tensor state in build to respect the layer's dtype."""
3326
self._alpha_tensor = ops.convert_to_tensor(self.alpha, dtype=self.compute_dtype)
3427
self._sigma_tensor = ops.convert_to_tensor(self.sigma, dtype=self.compute_dtype)
35-
36-
# Pre-compute the 1D Gaussian kernel
3728
kernel_size = ops.cast(2 * ops.round(3 * self._sigma_tensor) + 1, dtype="int32")
38-
ax = ops.arange(-ops.cast(kernel_size // 2, self.compute_dtype) + 1.0,
39-
ops.cast(kernel_size // 2, self.compute_dtype) + 1.0)
29+
ax = ops.arange(-ops.cast(kernel_size // 2, self.compute_dtype) + 1.0, ops.cast(kernel_size // 2, self.compute_dtype) + 1.0)
4030
kernel_1d = ops.exp(-(ax**2) / (2.0 * self._sigma_tensor**2))
4131
self.kernel_1d = kernel_1d / ops.sum(kernel_1d)
4232
self.built = True
4333

4434
def _separable_gaussian_filter_3d(self, tensor):
45-
"""Apply a 3D Gaussian filter using separable 1D convolutions."""
4635
depth_kernel = ops.reshape(self.kernel_1d, (-1, 1, 1, 1, 1))
4736
tensor = ops.conv(tensor, ops.cast(depth_kernel, dtype=tensor.dtype), padding='same')
48-
4937
height_kernel = ops.reshape(self.kernel_1d, (1, -1, 1, 1, 1))
5038
tensor = ops.conv(tensor, ops.cast(height_kernel, dtype=tensor.dtype), padding='same')
51-
5239
width_kernel = ops.reshape(self.kernel_1d, (1, 1, -1, 1, 1))
5340
tensor = ops.conv(tensor, ops.cast(width_kernel, dtype=tensor.dtype), padding='same')
54-
5541
return tensor
5642

5743
def call(self, inputs):
@@ -70,16 +56,10 @@ def call(self, inputs):
7056
label_volume = ops.cast(label_volume, dtype=compute_dtype)
7157

7258
input_shape = ops.shape(image_volume)
73-
B, D, H, W = input_shape[0], input_shape[1], input_shape[2], input_shape[3]
74-
C = input_shape[4]
75-
76-
# 1. Create a coarse random flow field.
77-
coarse_flow = random.uniform(
78-
shape=(B, self.grid_size[0], self.grid_size[1], self.grid_size[2], 3),
79-
minval=-1, maxval=1, dtype=compute_dtype
80-
)
81-
82-
# 2. Upsample the flow field.
59+
B, D, H, W, C = input_shape[0], input_shape[1], input_shape[2], input_shape[3], input_shape[4]
60+
61+
coarse_flow = random.uniform(shape=(B, self.grid_size[0], self.grid_size[1], self.grid_size[2], 3), minval=-1, maxval=1, dtype=compute_dtype)
62+
8363
flow = coarse_flow
8464
flow_shape = ops.shape(flow)
8565
flow = ops.reshape(flow, (flow_shape[0] * flow_shape[1], flow_shape[2], flow_shape[3], 3))
@@ -91,71 +71,49 @@ def call(self, inputs):
9171
flow = ops.image.resize(flow, (D, 1), interpolation="bicubic")
9272
flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], flow_shape[2], D, 3))
9373
flow = ops.transpose(flow, (0, 3, 1, 2, 4))
94-
95-
# 3. Apply Gaussian smoothing.
74+
9675
flow_components = ops.unstack(flow, axis=-1)
9776
smoothed_components = []
9877
for component in flow_components:
99-
component_reshaped = ops.expand_dims(component, axis=-1)
100-
smoothed_component = self._separable_gaussian_filter_3d(component_reshaped)
101-
smoothed_components.append(ops.squeeze(smoothed_component, axis=-1))
78+
smoothed_components.append(ops.squeeze(self._separable_gaussian_filter_3d(ops.expand_dims(component, axis=-1)), axis=-1))
10279
smoothed_flow = ops.stack(smoothed_components, axis=-1)
10380

104-
# 4. Scale the flow field and create warp grid.
10581
flow = smoothed_flow * self._alpha_tensor
106-
grid_d, grid_h, grid_w = ops.meshgrid(
107-
ops.arange(D, dtype=compute_dtype),
108-
ops.arange(H, dtype=compute_dtype),
109-
ops.arange(W, dtype=compute_dtype),
110-
indexing='ij'
111-
)
82+
grid_d, grid_h, grid_w = ops.meshgrid(ops.arange(D, dtype=compute_dtype), ops.arange(H, dtype=compute_dtype), ops.arange(W, dtype=compute_dtype), indexing='ij')
11283
grid = ops.stack([grid_d, grid_h, grid_w], axis=-1)
11384
warp_grid = ops.expand_dims(grid, 0) + flow
11485

115-
11686
batched_coords = ops.transpose(warp_grid, (0, 4, 1, 2, 3))
11787

118-
119-
deformed_images_batched = []
120-
for i in range(B):
121-
122-
image_slice = image_volume[i]
123-
coords = batched_coords[i]
124-
125-
126-
image_slice_transposed = ops.transpose(image_slice, (3, 0, 1, 2))
127-
88+
def perform_map(elems):
89+
image_slice, label_slice, coords = elems
12890
deformed_channels = []
91+
image_slice_transposed = ops.transpose(image_slice, (3, 0, 1, 2))
92+
# The channel dimension C is a static value when the graph is built
12993
for c in range(C):
130-
131-
deformed_channel = ops.image.map_coordinates(
132-
image_slice_transposed[c], coords, order=1
133-
)
134-
deformed_channels.append(deformed_channel)
135-
136-
# Stack and transpose back to (D, H, W, C)
94+
deformed_channels.append(ops.image.map_coordinates(image_slice_transposed[c], coords, order=1))
13795
deformed_image_slice = ops.stack(deformed_channels, axis=0)
138-
deformed_images_batched.append(ops.transpose(deformed_image_slice, (1, 2, 3, 0)))
139-
140-
deformed_image = ops.stack(deformed_images_batched, axis=0)
141-
142-
# Process Labels: loop over the batch dimension.
143-
deformed_labels_batched = []
144-
for i in range(B):
145-
label_slice = label_volume[i]
146-
coords = batched_coords[i]
147-
148-
96+
deformed_image_slice = ops.transpose(deformed_image_slice, (1, 2, 3, 0))
14997
label_channel = ops.squeeze(label_slice, axis=-1)
150-
deformed_label_channel = ops.image.map_coordinates(
151-
label_channel, coords, order=0
152-
)
153-
154-
deformed_labels_batched.append(ops.expand_dims(deformed_label_channel, axis=-1))
155-
156-
deformed_label = ops.stack(deformed_labels_batched, axis=0)
157-
158-
98+
deformed_label_channel = ops.image.map_coordinates(label_channel, coords, order=0)
99+
deformed_label_slice = ops.expand_dims(deformed_label_channel, axis=-1)
100+
return deformed_image_slice, deformed_label_slice
101+
102+
if backend.backend() == "tensorflow":
103+
import tensorflow as tf
104+
deformed_image, deformed_label = tf.map_fn(perform_map, elems=(image_volume, label_volume, batched_coords), dtype=(compute_dtype, compute_dtype))
105+
elif backend.backend() == "jax":
106+
import jax
107+
deformed_image, deformed_label = jax.lax.map(perform_map, xs=(image_volume, label_volume, batched_coords))
108+
else:
109+
deformed_images_list = []
110+
deformed_labels_list = []
111+
for i in range(B):
112+
img_slice, lbl_slice = perform_map((image_volume[i], label_volume[i], batched_coords[i]))
113+
deformed_images_list.append(img_slice)
114+
deformed_labels_list.append(lbl_slice)
115+
deformed_image = ops.stack(deformed_images_list, axis=0)
116+
deformed_label = ops.stack(deformed_labels_list, axis=0)
159117

160118
deformed_image = ops.cast(deformed_image, original_image_dtype)
161119
deformed_label = ops.cast(deformed_label, original_label_dtype)
@@ -167,16 +125,10 @@ def call(self, inputs):
167125
return deformed_image, deformed_label
168126

169127
def compute_output_shape(self, input_shape):
170-
"""Computes the output shape of the layer."""
171128
image_shape, label_shape = input_shape
172129
return image_shape, label_shape
173130

174131
def get_config(self):
175132
config = super().get_config()
176-
config.update({
177-
"grid_size": self.grid_size,
178-
"alpha": self.alpha,
179-
"sigma": self.sigma,
180-
"data_format": self.data_format,
181-
})
133+
config.update({"grid_size": self.grid_size, "alpha": self.alpha, "sigma": self.sigma, "data_format": self.data_format})
182134
return config
Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
1+
# Add keras.utils for the random seed
2+
from keras import utils
23
import os
34
import numpy as np
45
import keras
@@ -10,65 +11,47 @@
1011

1112

1213
class RandomElasticDeformation3DTest(TestCase):
13-
14-
15-
16-
1714
def test_layer_basics(self):
18-
15+
# --- BEST PRACTICE: Add a seed for reproducibility ---
16+
utils.set_random_seed(0)
1917
layer = RandomElasticDeformation3D(
2018
grid_size=(4, 4, 4),
2119
alpha=10.0,
2220
sigma=2.0,
2321
)
2422
image = ops.ones((2, 32, 32, 32, 3), dtype="float32")
2523
label = ops.ones((2, 32, 32, 32, 1), dtype="int32")
26-
2724
output_image, output_label = layer((image, label))
28-
29-
# Check shapes
3025
self.assertEqual(ops.shape(image), ops.shape(output_image))
3126
self.assertEqual(ops.shape(label), ops.shape(output_label))
32-
33-
# Check dtypes
3427
self.assertEqual(image.dtype, output_image.dtype)
3528
self.assertEqual(label.dtype, output_label.dtype)
3629

37-
38-
3930
def test_serialization(self):
40-
# 1. Instantiate the layer
31+
# --- BEST PRACTICE: Add a seed for reproducibility ---
32+
utils.set_random_seed(0)
4133
layer = RandomElasticDeformation3D(
4234
grid_size=(3, 3, 3),
4335
alpha=50.0,
4436
sigma=5.0,
4537
)
46-
47-
# 2. Create dummy input data
4838
image_data = ops.ones((2, 16, 16, 16, 3), dtype="float32")
4939
label_data = ops.ones((2, 16, 16, 16, 1), dtype="int32")
5040
input_data = (image_data, label_data)
51-
52-
# 3. Build a functional Model that uses the layer
5341
image_input = Input(shape=(16, 16, 16, 3), dtype="float32")
5442
label_input = Input(shape=(16, 16, 16, 1), dtype="int32")
5543
outputs = layer((image_input, label_input))
5644
model = Model(inputs=[image_input, label_input], outputs=outputs)
57-
58-
# 4. Get the output of the original model
5945
original_output_image, original_output_label = model(input_data)
60-
61-
# 5. Save and load the model
6246
path = os.path.join(self.get_temp_dir(), "model.keras")
63-
model.save(path, save_format="keras_v3")
47+
48+
# --- FIX: Remove the deprecated save_format argument ---
49+
model.save(path)
50+
6451
loaded_model = keras.models.load_model(
6552
path, custom_objects={"RandomElasticDeformation3D": RandomElasticDeformation3D}
6653
)
67-
68-
# 6. Get the output of the loaded model
6954
loaded_output_image, loaded_output_label = loaded_model(input_data)
70-
71-
# 7. Assert that the outputs are the same
7255
np.testing.assert_allclose(
7356
ops.convert_to_numpy(original_output_image),
7457
ops.convert_to_numpy(loaded_output_image),
@@ -78,20 +61,14 @@ def test_serialization(self):
7861
ops.convert_to_numpy(loaded_output_label),
7962
)
8063

81-
82-
8364
def test_label_values_are_preserved(self):
84-
65+
# --- BEST PRACTICE: Add a seed for reproducibility ---
66+
utils.set_random_seed(0)
8567
image = ops.zeros(shape=(1, 16, 16, 16, 1), dtype="float32")
86-
87-
8868
label_arange = ops.arange(16**3, dtype="int32")
8969
label = ops.reshape(label_arange, (1, 16, 16, 16, 1)) % 4
90-
9170
layer = RandomElasticDeformation3D(alpha=80.0, sigma=8.0)
9271
_, output_label = layer((image, label))
93-
94-
9572
output_values = set(np.unique(ops.convert_to_numpy(output_label)))
9673
expected_values = {0, 1, 2, 3}
9774
self.assertLessEqual(output_values, expected_values)

0 commit comments

Comments
 (0)