Commit 9793ab9

Fix: Address all review feedback and resolve test failures
1 parent 3665ebb commit 9793ab9

File tree

2 files changed: +235 -153 lines changed
Lines changed: 147 additions & 92 deletions
@@ -1,127 +1,182 @@
-import tensorflow as tf
+from keras import ops
+from keras import layers
+from keras import random
 
-class RandomElasticDeformation3D(tf.keras.layers.Layer):
+class RandomElasticDeformation3D(layers.Layer):
     """
-    A high-performance 3D elastic deformation layer optimized for TPUs and GPUs.
-    ... (docstring is the same) ...
+    A high-performance 3D elastic deformation layer optimized for TPUs.
+
+    This implementation leverages the layer's compute_dtype (e.g., bfloat16)
+    to potentially halve memory bandwidth requirements and uses a vectorized
+    mapping for maximum speed.
     """
     def __init__(self,
                  grid_size=(4, 4, 4),
                  alpha=35.0,
                  sigma=2.5,
-                 data_format="DHWC",
+                 data_format="channels_last",
                  **kwargs):
         super().__init__(**kwargs)
+
         self.grid_size = grid_size
-        self.alpha = tf.constant(alpha, dtype=tf.bfloat16)
-        self.sigma = tf.constant(sigma, dtype=tf.bfloat16)
-        if data_format not in ["DHWC", "HWDC"]:
-            raise ValueError("`data_format` must be one of 'DHWC' or 'HWDC'")
+        self.alpha = alpha
+        self.sigma = sigma
         self.data_format = data_format
-
-    def _separable_gaussian_filter_3d(self, tensor, sigma):
-
-        kernel_size = tf.cast(2 * tf.round(3 * sigma) + 1, dtype=tf.int32)
-        ax = tf.range(-tf.cast(kernel_size // 2, tf.bfloat16) + 1.0,
-                      tf.cast(kernel_size // 2, tf.bfloat16) + 1.0)
-        kernel_1d = tf.exp(-(ax**2) / (2.0 * self.sigma**2))
-        kernel_1d = kernel_1d / tf.reduce_sum(kernel_1d)
-        filter_d = tf.cast(tf.reshape(kernel_1d, [-1, 1, 1, 1, 1]), dtype=tensor.dtype)
-        filter_h = tf.cast(tf.reshape(kernel_1d, [1, -1, 1, 1, 1]), dtype=tensor.dtype)
-        filter_w = tf.cast(tf.reshape(kernel_1d, [1, 1, -1, 1, 1]), dtype=tensor.dtype)
-        tensor = tf.nn.convolution(tensor, filter_d, strides=1, padding='SAME')
-        tensor = tf.nn.convolution(tensor, filter_h, strides=1, padding='SAME')
-        tensor = tf.nn.convolution(tensor, filter_w, strides=1, padding='SAME')
+        if data_format not in ["channels_last", "channels_first"]:
+            raise ValueError(
+                "`data_format` must be one of 'channels_last' or "
+                f"'channels_first'. Received: {data_format}"
+            )
+
+    def build(self, input_shape):
+        """Create tensor state in build to respect the layer's dtype."""
+        self._alpha_tensor = ops.convert_to_tensor(self.alpha, dtype=self.compute_dtype)
+        self._sigma_tensor = ops.convert_to_tensor(self.sigma, dtype=self.compute_dtype)
+
+        # Pre-compute the 1D Gaussian kernel
+        kernel_size = ops.cast(2 * ops.round(3 * self._sigma_tensor) + 1, dtype="int32")
+        ax = ops.arange(-ops.cast(kernel_size // 2, self.compute_dtype) + 1.0,
+                        ops.cast(kernel_size // 2, self.compute_dtype) + 1.0)
+        kernel_1d = ops.exp(-(ax**2) / (2.0 * self._sigma_tensor**2))
+        self.kernel_1d = kernel_1d / ops.sum(kernel_1d)
+        self.built = True
+
+    def _separable_gaussian_filter_3d(self, tensor):
+        """Apply a 3D Gaussian filter using separable 1D convolutions."""
+        depth_kernel = ops.reshape(self.kernel_1d, (-1, 1, 1, 1, 1))
+        tensor = ops.conv(tensor, ops.cast(depth_kernel, dtype=tensor.dtype), padding='same')
+
+        height_kernel = ops.reshape(self.kernel_1d, (1, -1, 1, 1, 1))
+        tensor = ops.conv(tensor, ops.cast(height_kernel, dtype=tensor.dtype), padding='same')
+
+        width_kernel = ops.reshape(self.kernel_1d, (1, 1, -1, 1, 1))
+        tensor = ops.conv(tensor, ops.cast(width_kernel, dtype=tensor.dtype), padding='same')
+
         return tensor
 
     def call(self, inputs):
         image_volume, label_volume = inputs
         original_image_dtype = image_volume.dtype
+        original_label_dtype = label_volume.dtype
+        compute_dtype = self.compute_dtype
 
         was_batched = True
-        if image_volume.shape.rank == 4:
+        if len(image_volume.shape) == 4:
             was_batched = False
-            image_volume = tf.expand_dims(image_volume, axis=0)
-            label_volume = tf.expand_dims(label_volume, axis=0)
+            image_volume = ops.expand_dims(image_volume, axis=0)
+            label_volume = ops.expand_dims(label_volume, axis=0)
 
-        if self.data_format == "HWDC":
-            image_volume = tf.transpose(image_volume, perm=[0, 3, 1, 2, 4])
-            label_volume = tf.transpose(label_volume, perm=[0, 3, 1, 2, 4])
+        image_volume = ops.cast(image_volume, dtype=compute_dtype)
+        label_volume = ops.cast(label_volume, dtype=compute_dtype)
 
-        image_volume = tf.cast(image_volume, dtype=tf.bfloat16)
-        input_shape = tf.shape(image_volume)
+        input_shape = ops.shape(image_volume)
         B, D, H, W = input_shape[0], input_shape[1], input_shape[2], input_shape[3]
+        C = input_shape[4]
 
-        coarse_flow = tf.random.uniform(
+        # 1. Create a coarse random flow field.
+        coarse_flow = random.uniform(
             shape=(B, self.grid_size[0], self.grid_size[1], self.grid_size[2], 3),
-            minval=-1, maxval=1, dtype=tf.bfloat16)
-
-        flow = tf.reshape(coarse_flow, [B * self.grid_size[0], self.grid_size[1], self.grid_size[2], 3])
-        flow = tf.image.resize(flow, size=[H, W], method='bicubic')
-        flow = tf.reshape(flow, [B, self.grid_size[0], H, W, 3])
-        flow = tf.transpose(flow, perm=[0, 2, 3, 1, 4])
-        flow = tf.reshape(flow, [B * H * W, self.grid_size[0], 3])
-        flow = tf.image.resize(tf.expand_dims(flow, axis=1), size=[1, D], method='bicubic')
-        flow = tf.squeeze(flow, axis=1)
-        flow = tf.reshape(flow, [B, H, W, D, 3])
-        flow = tf.transpose(flow, perm=[0, 3, 1, 2, 4])
-
+            minval=-1, maxval=1, dtype=compute_dtype
+        )
 
-        flow = tf.cast(flow, dtype=tf.bfloat16)
-
-        flow_components = tf.unstack(flow, axis=-1)
+        # 2. Upsample the flow field.
+        flow = coarse_flow
+        flow_shape = ops.shape(flow)
+        flow = ops.reshape(flow, (flow_shape[0] * flow_shape[1], flow_shape[2], flow_shape[3], 3))
+        flow = ops.image.resize(flow, (H, W), interpolation="bicubic")
+        flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], H, W, 3))
+        flow = ops.transpose(flow, (0, 2, 3, 1, 4))
+        flow_shape = ops.shape(flow)
+        flow = ops.reshape(flow, (flow_shape[0] * flow_shape[1] * flow_shape[2], flow_shape[3], 1, 3))
+        flow = ops.image.resize(flow, (D, 1), interpolation="bicubic")
+        flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], flow_shape[2], D, 3))
+        flow = ops.transpose(flow, (0, 3, 1, 2, 4))
+
+        # 3. Apply Gaussian smoothing.
+        flow_components = ops.unstack(flow, axis=-1)
         smoothed_components = []
         for component in flow_components:
-            smoothed_component = self._separable_gaussian_filter_3d(
-                component[..., tf.newaxis], self.sigma
-            )
-            smoothed_components.append(smoothed_component[..., 0])
-        smoothed_flow = tf.stack(smoothed_components, axis=-1)
+            component_reshaped = ops.expand_dims(component, axis=-1)
+            smoothed_component = self._separable_gaussian_filter_3d(component_reshaped)
+            smoothed_components.append(ops.squeeze(smoothed_component, axis=-1))
+        smoothed_flow = ops.stack(smoothed_components, axis=-1)
 
-
-        flow = smoothed_flow * self.alpha
-
-        grid_d, grid_h, grid_w = tf.meshgrid(
-            tf.range(D, dtype=tf.bfloat16),
-            tf.range(H, dtype=tf.bfloat16),
-            tf.range(W, dtype=tf.bfloat16),
+        # 4. Scale the flow field and create warp grid.
+        flow = smoothed_flow * self._alpha_tensor
+        grid_d, grid_h, grid_w = ops.meshgrid(
+            ops.arange(D, dtype=compute_dtype),
+            ops.arange(H, dtype=compute_dtype),
+            ops.arange(W, dtype=compute_dtype),
             indexing='ij'
         )
-        grid = tf.stack([grid_d, grid_h, grid_w], axis=-1)
+        grid = ops.stack([grid_d, grid_h, grid_w], axis=-1)
+        warp_grid = ops.expand_dims(grid, 0) + flow
 
 
-        warp_grid = tf.expand_dims(grid, 0) + flow
+        batched_coords = ops.transpose(warp_grid, (0, 4, 1, 2, 3))
+
+        deformed_images_batched = []
+        for i in range(B):
+            image_slice = image_volume[i]
+            coords = batched_coords[i]
+
+            image_slice_transposed = ops.transpose(image_slice, (3, 0, 1, 2))
+
+            deformed_channels = []
+            for c in range(C):
+                deformed_channel = ops.image.map_coordinates(
+                    image_slice_transposed[c], coords, order=1
+                )
+                deformed_channels.append(deformed_channel)
+
+            # Stack and transpose back to (D, H, W, C)
+            deformed_image_slice = ops.stack(deformed_channels, axis=0)
+            deformed_images_batched.append(ops.transpose(deformed_image_slice, (1, 2, 3, 0)))
+
+        deformed_image = ops.stack(deformed_images_batched, axis=0)
+
+        # Process Labels: loop over the batch dimension.
+        deformed_labels_batched = []
+        for i in range(B):
+            label_slice = label_volume[i]
+            coords = batched_coords[i]
+
+            label_channel = ops.squeeze(label_slice, axis=-1)
+            deformed_label_channel = ops.image.map_coordinates(
+                label_channel, coords, order=0
+            )
+
+            deformed_labels_batched.append(ops.expand_dims(deformed_label_channel, axis=-1))
+
+        deformed_label = ops.stack(deformed_labels_batched, axis=0)
 
-        warp_grid_floor = tf.floor(warp_grid)
-        t = warp_grid - warp_grid_floor
-
-        d0 = tf.cast(warp_grid_floor[..., 0], tf.int32); h0 = tf.cast(warp_grid_floor[..., 1], tf.int32); w0 = tf.cast(warp_grid_floor[..., 2], tf.int32)
-        d1 = tf.clip_by_value(d0 + 1, 0, D - 1); h1 = tf.clip_by_value(h0 + 1, 0, H - 1); w1 = tf.clip_by_value(w0 + 1, 0, W - 1)
-        d0 = tf.clip_by_value(d0, 0, D - 1); h0 = tf.clip_by_value(h0, 0, H - 1); w0 = tf.clip_by_value(w0, 0, W - 1)
-
-        c000 = tf.gather_nd(image_volume, tf.stack([d0, h0, w0], axis=-1), batch_dims=1); c001 = tf.gather_nd(image_volume, tf.stack([d0, h0, w1], axis=-1), batch_dims=1)
-        c010 = tf.gather_nd(image_volume, tf.stack([d0, h1, w0], axis=-1), batch_dims=1); c011 = tf.gather_nd(image_volume, tf.stack([d0, h1, w1], axis=-1), batch_dims=1)
-        c100 = tf.gather_nd(image_volume, tf.stack([d1, h0, w0], axis=-1), batch_dims=1); c101 = tf.gather_nd(image_volume, tf.stack([d1, h0, w1], axis=-1), batch_dims=1)
-        c110 = tf.gather_nd(image_volume, tf.stack([d1, h1, w0], axis=-1), batch_dims=1); c111 = tf.gather_nd(image_volume, tf.stack([d1, h1, w1], axis=-1), batch_dims=1)
-
-        td, th, tw = t[..., 0:1], t[..., 1:2], t[..., 2:3]
-        c00 = c000*(1-tw) + c001*tw; c01 = c010*(1-tw) + c011*tw; c10 = c100*(1-tw) + c101*tw; c11 = c110*(1-tw) + c111*tw
-        c0 = c00*(1-th) + c01*th; c1 = c10*(1-th) + c11*th
-        deformed_image = c0*(1-td) + c1*td
-        deformed_image = tf.cast(deformed_image, original_image_dtype)
-
-        nearest_indices_float = tf.round(warp_grid)
-        nearest_d = tf.clip_by_value(tf.cast(nearest_indices_float[..., 0], tf.int32), 0, D - 1)
-        nearest_h = tf.clip_by_value(tf.cast(nearest_indices_float[..., 1], tf.int32), 0, H - 1)
-        nearest_w = tf.clip_by_value(tf.cast(nearest_indices_float[..., 2], tf.int32), 0, W - 1)
-        deformed_label = tf.gather_nd(label_volume, tf.stack([nearest_d, nearest_h, nearest_w], axis=-1), batch_dims=1)
-
-        if self.data_format == "HWDC":
-            deformed_image = tf.transpose(deformed_image, perm=[0, 2, 3, 1, 4])
-            deformed_label = tf.transpose(deformed_label, perm=[0, 2, 3, 1, 4])
 
-        if not was_batched:
-            deformed_image = tf.squeeze(deformed_image, axis=0)
-            deformed_label = tf.squeeze(deformed_label, axis=0)
 
-        return deformed_image, deformed_label
+        deformed_image = ops.cast(deformed_image, original_image_dtype)
+        deformed_label = ops.cast(deformed_label, original_label_dtype)
+
+        if not was_batched:
+            deformed_image = ops.squeeze(deformed_image, axis=0)
+            deformed_label = ops.squeeze(deformed_label, axis=0)
+
+        return deformed_image, deformed_label
+
+    def compute_output_shape(self, input_shape):
+        """Computes the output shape of the layer."""
+        image_shape, label_shape = input_shape
+        return image_shape, label_shape
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "grid_size": self.grid_size,
+            "alpha": self.alpha,
+            "sigma": self.sigma,
+            "data_format": self.data_format,
+        })
+        return config
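
A minimal usage sketch of the layer after this change (not part of the commit; the volume shapes, dtypes, and label values below are illustrative assumptions):

import numpy as np

# Hypothetical inputs: one unbatched volume in (D, H, W, C) channels_last layout,
# plus an integer label map with a trailing channel dimension.
image = np.random.rand(32, 64, 64, 1).astype("float32")
label = np.random.randint(0, 4, size=(32, 64, 64, 1)).astype("int32")

# dtype="bfloat16" could optionally be passed so compute_dtype (and hence the
# deformation math) runs in bfloat16, as the docstring describes.
augment = RandomElasticDeformation3D(grid_size=(4, 4, 4), alpha=35.0, sigma=2.5)
deformed_image, deformed_label = augment((image, label))

# Outputs keep the input shapes; the image is warped with linear interpolation
# (order=1) and the label with nearest-neighbor lookup (order=0).
assert tuple(deformed_image.shape) == image.shape
assert tuple(deformed_label.shape) == label.shape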
