Skip to content

Commit 3665ebb

Browse files
feat(layers): Add TPU-optimized 3D Elastic Deformation layer
1 parent 049b25d commit 3665ebb

File tree

3 files changed

+201
-0
lines changed

3 files changed

+201
-0
lines changed

keras_hub/api/layers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,7 @@
147147
from keras_hub.src.models.xception.xception_image_converter import (
148148
XceptionImageConverter as XceptionImageConverter,
149149
)
150+
151+
from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import (
152+
RandomElasticDeformation3D,
153+
)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import tensorflow as tf
2+
3+
class RandomElasticDeformation3D(tf.keras.layers.Layer):
4+
"""
5+
A high-performance 3D elastic deformation layer optimized for TPUs and GPUs.
6+
... (docstring is the same) ...
7+
"""
8+
def __init__(self,
9+
grid_size=(4, 4, 4),
10+
alpha=35.0,
11+
sigma=2.5,
12+
data_format="DHWC",
13+
**kwargs):
14+
super().__init__(**kwargs)
15+
self.grid_size = grid_size
16+
self.alpha = tf.constant(alpha, dtype=tf.bfloat16)
17+
self.sigma = tf.constant(sigma, dtype=tf.bfloat16)
18+
if data_format not in ["DHWC", "HWDC"]:
19+
raise ValueError("`data_format` must be one of 'DHWC' or 'HWDC'")
20+
self.data_format = data_format
21+
22+
def _separable_gaussian_filter_3d(self, tensor, sigma):
23+
24+
kernel_size = tf.cast(2 * tf.round(3 * sigma) + 1, dtype=tf.int32)
25+
ax = tf.range(-tf.cast(kernel_size // 2, tf.bfloat16) + 1.0,
26+
tf.cast(kernel_size // 2, tf.bfloat16) + 1.0)
27+
kernel_1d = tf.exp(-(ax**2) / (2.0 * self.sigma**2))
28+
kernel_1d = kernel_1d / tf.reduce_sum(kernel_1d)
29+
filter_d = tf.cast(tf.reshape(kernel_1d, [-1, 1, 1, 1, 1]), dtype=tensor.dtype)
30+
filter_h = tf.cast(tf.reshape(kernel_1d, [1, -1, 1, 1, 1]), dtype=tensor.dtype)
31+
filter_w = tf.cast(tf.reshape(kernel_1d, [1, 1, -1, 1, 1]), dtype=tensor.dtype)
32+
tensor = tf.nn.convolution(tensor, filter_d, strides=1, padding='SAME')
33+
tensor = tf.nn.convolution(tensor, filter_h, strides=1, padding='SAME')
34+
tensor = tf.nn.convolution(tensor, filter_w, strides=1, padding='SAME')
35+
return tensor
36+
37+
def call(self, inputs):
38+
image_volume, label_volume = inputs
39+
original_image_dtype = image_volume.dtype
40+
41+
was_batched = True
42+
if image_volume.shape.rank == 4:
43+
was_batched = False
44+
image_volume = tf.expand_dims(image_volume, axis=0)
45+
label_volume = tf.expand_dims(label_volume, axis=0)
46+
47+
if self.data_format == "HWDC":
48+
image_volume = tf.transpose(image_volume, perm=[0, 3, 1, 2, 4])
49+
label_volume = tf.transpose(label_volume, perm=[0, 3, 1, 2, 4])
50+
51+
image_volume = tf.cast(image_volume, dtype=tf.bfloat16)
52+
input_shape = tf.shape(image_volume)
53+
B, D, H, W = input_shape[0], input_shape[1], input_shape[2], input_shape[3]
54+
55+
coarse_flow = tf.random.uniform(
56+
shape=(B, self.grid_size[0], self.grid_size[1], self.grid_size[2], 3),
57+
minval=-1, maxval=1, dtype=tf.bfloat16)
58+
59+
flow = tf.reshape(coarse_flow, [B * self.grid_size[0], self.grid_size[1], self.grid_size[2], 3])
60+
flow = tf.image.resize(flow, size=[H, W], method='bicubic')
61+
flow = tf.reshape(flow, [B, self.grid_size[0], H, W, 3])
62+
flow = tf.transpose(flow, perm=[0, 2, 3, 1, 4])
63+
flow = tf.reshape(flow, [B * H * W, self.grid_size[0], 3])
64+
flow = tf.image.resize(tf.expand_dims(flow, axis=1), size=[1, D], method='bicubic')
65+
flow = tf.squeeze(flow, axis=1)
66+
flow = tf.reshape(flow, [B, H, W, D, 3])
67+
flow = tf.transpose(flow, perm=[0, 3, 1, 2, 4])
68+
69+
70+
flow = tf.cast(flow, dtype=tf.bfloat16)
71+
72+
flow_components = tf.unstack(flow, axis=-1)
73+
smoothed_components = []
74+
for component in flow_components:
75+
smoothed_component = self._separable_gaussian_filter_3d(
76+
component[..., tf.newaxis], self.sigma
77+
)
78+
smoothed_components.append(smoothed_component[..., 0])
79+
smoothed_flow = tf.stack(smoothed_components, axis=-1)
80+
81+
82+
flow = smoothed_flow * self.alpha
83+
84+
grid_d, grid_h, grid_w = tf.meshgrid(
85+
tf.range(D, dtype=tf.bfloat16),
86+
tf.range(H, dtype=tf.bfloat16),
87+
tf.range(W, dtype=tf.bfloat16),
88+
indexing='ij'
89+
)
90+
grid = tf.stack([grid_d, grid_h, grid_w], axis=-1)
91+
92+
93+
warp_grid = tf.expand_dims(grid, 0) + flow
94+
95+
warp_grid_floor = tf.floor(warp_grid)
96+
t = warp_grid - warp_grid_floor
97+
98+
d0 = tf.cast(warp_grid_floor[..., 0], tf.int32); h0 = tf.cast(warp_grid_floor[..., 1], tf.int32); w0 = tf.cast(warp_grid_floor[..., 2], tf.int32)
99+
d1 = tf.clip_by_value(d0 + 1, 0, D - 1); h1 = tf.clip_by_value(h0 + 1, 0, H - 1); w1 = tf.clip_by_value(w0 + 1, 0, W - 1)
100+
d0 = tf.clip_by_value(d0, 0, D - 1); h0 = tf.clip_by_value(h0, 0, H - 1); w0 = tf.clip_by_value(w0, 0, W - 1)
101+
102+
c000 = tf.gather_nd(image_volume, tf.stack([d0, h0, w0], axis=-1), batch_dims=1); c001 = tf.gather_nd(image_volume, tf.stack([d0, h0, w1], axis=-1), batch_dims=1)
103+
c010 = tf.gather_nd(image_volume, tf.stack([d0, h1, w0], axis=-1), batch_dims=1); c011 = tf.gather_nd(image_volume, tf.stack([d0, h1, w1], axis=-1), batch_dims=1)
104+
c100 = tf.gather_nd(image_volume, tf.stack([d1, h0, w0], axis=-1), batch_dims=1); c101 = tf.gather_nd(image_volume, tf.stack([d1, h0, w1], axis=-1), batch_dims=1)
105+
c110 = tf.gather_nd(image_volume, tf.stack([d1, h1, w0], axis=-1), batch_dims=1); c111 = tf.gather_nd(image_volume, tf.stack([d1, h1, w1], axis=-1), batch_dims=1)
106+
107+
td, th, tw = t[..., 0:1], t[..., 1:2], t[..., 2:3]
108+
c00 = c000*(1-tw) + c001*tw; c01 = c010*(1-tw) + c011*tw; c10 = c100*(1-tw) + c101*tw; c11 = c110*(1-tw) + c111*tw
109+
c0 = c00*(1-th) + c01*th; c1 = c10*(1-th) + c11*th
110+
deformed_image = c0*(1-td) + c1*td
111+
deformed_image = tf.cast(deformed_image, original_image_dtype)
112+
113+
nearest_indices_float = tf.round(warp_grid)
114+
nearest_d = tf.clip_by_value(tf.cast(nearest_indices_float[..., 0], tf.int32), 0, D - 1)
115+
nearest_h = tf.clip_by_value(tf.cast(nearest_indices_float[..., 1], tf.int32), 0, H - 1)
116+
nearest_w = tf.clip_by_value(tf.cast(nearest_indices_float[..., 2], tf.int32), 0, W - 1)
117+
deformed_label = tf.gather_nd(label_volume, tf.stack([nearest_d, nearest_h, nearest_w], axis=-1), batch_dims=1)
118+
119+
if self.data_format == "HWDC":
120+
deformed_image = tf.transpose(deformed_image, perm=[0, 2, 3, 1, 4])
121+
deformed_label = tf.transpose(deformed_label, perm=[0, 2, 3, 1, 4])
122+
123+
if not was_batched:
124+
deformed_image = tf.squeeze(deformed_image, axis=0)
125+
deformed_label = tf.squeeze(deformed_label, axis=0)
126+
127+
return deformed_image, deformed_label
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import tensorflow as tf
2+
from tensorflow import keras
3+
from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import RandomElasticDeformation3D
4+
5+
class RandomElasticDeformation3DTest(tf.test.TestCase):
6+
7+
def test_output_shape_is_same_as_input_dhwc(self):
8+
input_image = tf.random.uniform(shape=(2, 32, 64, 64, 3), dtype=tf.float32)
9+
input_label = tf.random.uniform(shape=(2, 32, 64, 64, 1), maxval=4, dtype=tf.int32)
10+
layer = RandomElasticDeformation3D(data_format="DHWC")
11+
output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32)))
12+
self.assertAllEqual(tf.shape(input_image), tf.shape(output_image))
13+
self.assertAllEqual(tf.shape(input_label), tf.shape(output_label))
14+
15+
def test_output_shape_is_same_as_input_hwdc(self):
16+
input_image = tf.random.uniform(shape=(2, 64, 64, 32, 3), dtype=tf.float32)
17+
input_label = tf.random.uniform(shape=(2, 64, 64, 32, 1), maxval=4, dtype=tf.int32)
18+
layer = RandomElasticDeformation3D(data_format="HWDC")
19+
output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32)))
20+
self.assertAllEqual(tf.shape(input_image), tf.shape(output_image))
21+
self.assertAllEqual(tf.shape(input_label), tf.shape(output_label))
22+
23+
def test_unbatched_input(self):
24+
input_image = tf.random.uniform(shape=(32, 64, 64, 3), dtype=tf.float32)
25+
input_label = tf.random.uniform(shape=(32, 64, 64, 1), maxval=4, dtype=tf.int32)
26+
layer = RandomElasticDeformation3D(data_format="DHWC")
27+
output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32)))
28+
self.assertAllEqual(tf.shape(input_image), tf.shape(output_image))
29+
self.assertEqual(tf.rank(output_image), 4)
30+
31+
def test_dtype_preservation(self):
32+
input_image = tf.random.uniform(shape=(2, 16, 16, 16, 3), dtype=tf.float32)
33+
input_label = tf.random.uniform(shape=(2, 16, 16, 16, 1), maxval=4, dtype=tf.int32)
34+
layer = RandomElasticDeformation3D()
35+
output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32)))
36+
self.assertEqual(output_image.dtype, tf.float32)
37+
self.assertEqual(output_label.dtype, tf.float32)
38+
39+
def test_label_values_are_preserved(self):
40+
input_image = tf.zeros(shape=(1, 16, 16, 16, 1), dtype=tf.float32)
41+
label_arange = tf.experimental.numpy.arange(16**3)
42+
input_label = tf.reshape(label_arange, (1, 16, 16, 16, 1))
43+
input_label = tf.cast(input_label, dtype=tf.float32) % 4
44+
45+
layer = RandomElasticDeformation3D(alpha=80.0, sigma=8.0)
46+
_, output_label = layer((input_image, input_label))
47+
48+
unique_values_tensor = tf.unique(tf.reshape(output_label, [-1]))[0]
49+
50+
51+
expected_values = [0., 1., 2., 3.]
52+
actual_values = unique_values_tensor.numpy().tolist()
53+
self.assertContainsSubset(expected_values, actual_values)
54+
55+
def test_config_serialization(self):
56+
layer = RandomElasticDeformation3D(
57+
grid_size=(3, 3, 3),
58+
alpha=50.0,
59+
sigma=5.0,
60+
data_format="HWDC"
61+
)
62+
config = layer.get_config()
63+
new_layer = RandomElasticDeformation3D.from_config(config)
64+
self.assertEqual(new_layer.grid_size, (3, 3, 3))
65+
self.assertAllClose(new_layer.alpha, tf.constant(50.0, dtype=tf.bfloat16))
66+
self.assertAllClose(new_layer.sigma, tf.constant(5.0, dtype=tf.bfloat16))
67+
self.assertEqual(new_layer.data_format, "HWDC")
68+
69+
if __name__ == "__main__":
70+
tf.test.main()

0 commit comments

Comments
 (0)