From 3665ebb4b67521572059fa6fbea430f4c6753ddf Mon Sep 17 00:00:00 2001 From: ashpakshaikh26732 Date: Thu, 25 Sep 2025 09:57:31 +0000 Subject: [PATCH 1/5] feat(layers): Add TPU-optimized 3D Elastic Deformation layer --- keras_hub/api/layers/__init__.py | 4 + .../random_elastic_deformation_3d.py | 127 ++++++++++++++++++ .../random_elastic_deformation_3d_test.py | 70 ++++++++++ 3 files changed, 201 insertions(+) create mode 100644 keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py create mode 100644 keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index f90c214d6b..98a0298e61 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -147,3 +147,7 @@ from keras_hub.src.models.xception.xception_image_converter import ( XceptionImageConverter as XceptionImageConverter, ) + +from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import ( + RandomElasticDeformation3D, +) diff --git a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py new file mode 100644 index 0000000000..69b43e558d --- /dev/null +++ b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py @@ -0,0 +1,127 @@ +import tensorflow as tf + +class RandomElasticDeformation3D(tf.keras.layers.Layer): + """ + A high-performance 3D elastic deformation layer optimized for TPUs and GPUs. + ... (docstring is the same) ... + """ + def __init__(self, + grid_size=(4, 4, 4), + alpha=35.0, + sigma=2.5, + data_format="DHWC", + **kwargs): + super().__init__(**kwargs) + self.grid_size = grid_size + self.alpha = tf.constant(alpha, dtype=tf.bfloat16) + self.sigma = tf.constant(sigma, dtype=tf.bfloat16) + if data_format not in ["DHWC", "HWDC"]: + raise ValueError("`data_format` must be one of 'DHWC' or 'HWDC'") + self.data_format = data_format + + def _separable_gaussian_filter_3d(self, tensor, sigma): + + kernel_size = tf.cast(2 * tf.round(3 * sigma) + 1, dtype=tf.int32) + ax = tf.range(-tf.cast(kernel_size // 2, tf.bfloat16) + 1.0, + tf.cast(kernel_size // 2, tf.bfloat16) + 1.0) + kernel_1d = tf.exp(-(ax**2) / (2.0 * self.sigma**2)) + kernel_1d = kernel_1d / tf.reduce_sum(kernel_1d) + filter_d = tf.cast(tf.reshape(kernel_1d, [-1, 1, 1, 1, 1]), dtype=tensor.dtype) + filter_h = tf.cast(tf.reshape(kernel_1d, [1, -1, 1, 1, 1]), dtype=tensor.dtype) + filter_w = tf.cast(tf.reshape(kernel_1d, [1, 1, -1, 1, 1]), dtype=tensor.dtype) + tensor = tf.nn.convolution(tensor, filter_d, strides=1, padding='SAME') + tensor = tf.nn.convolution(tensor, filter_h, strides=1, padding='SAME') + tensor = tf.nn.convolution(tensor, filter_w, strides=1, padding='SAME') + return tensor + + def call(self, inputs): + image_volume, label_volume = inputs + original_image_dtype = image_volume.dtype + + was_batched = True + if image_volume.shape.rank == 4: + was_batched = False + image_volume = tf.expand_dims(image_volume, axis=0) + label_volume = tf.expand_dims(label_volume, axis=0) + + if self.data_format == "HWDC": + image_volume = tf.transpose(image_volume, perm=[0, 3, 1, 2, 4]) + label_volume = tf.transpose(label_volume, perm=[0, 3, 1, 2, 4]) + + image_volume = tf.cast(image_volume, dtype=tf.bfloat16) + input_shape = tf.shape(image_volume) + B, D, H, W = input_shape[0], input_shape[1], input_shape[2], input_shape[3] + + coarse_flow = tf.random.uniform( + shape=(B, self.grid_size[0], self.grid_size[1], self.grid_size[2], 3), + minval=-1, maxval=1, dtype=tf.bfloat16) + + flow = tf.reshape(coarse_flow, [B * self.grid_size[0], self.grid_size[1], self.grid_size[2], 3]) + flow = tf.image.resize(flow, size=[H, W], method='bicubic') + flow = tf.reshape(flow, [B, self.grid_size[0], H, W, 3]) + flow = tf.transpose(flow, perm=[0, 2, 3, 1, 4]) + flow = tf.reshape(flow, [B * H * W, self.grid_size[0], 3]) + flow = tf.image.resize(tf.expand_dims(flow, axis=1), size=[1, D], method='bicubic') + flow = tf.squeeze(flow, axis=1) + flow = tf.reshape(flow, [B, H, W, D, 3]) + flow = tf.transpose(flow, perm=[0, 3, 1, 2, 4]) + + + flow = tf.cast(flow, dtype=tf.bfloat16) + + flow_components = tf.unstack(flow, axis=-1) + smoothed_components = [] + for component in flow_components: + smoothed_component = self._separable_gaussian_filter_3d( + component[..., tf.newaxis], self.sigma + ) + smoothed_components.append(smoothed_component[..., 0]) + smoothed_flow = tf.stack(smoothed_components, axis=-1) + + + flow = smoothed_flow * self.alpha + + grid_d, grid_h, grid_w = tf.meshgrid( + tf.range(D, dtype=tf.bfloat16), + tf.range(H, dtype=tf.bfloat16), + tf.range(W, dtype=tf.bfloat16), + indexing='ij' + ) + grid = tf.stack([grid_d, grid_h, grid_w], axis=-1) + + + warp_grid = tf.expand_dims(grid, 0) + flow + + warp_grid_floor = tf.floor(warp_grid) + t = warp_grid - warp_grid_floor + + d0 = tf.cast(warp_grid_floor[..., 0], tf.int32); h0 = tf.cast(warp_grid_floor[..., 1], tf.int32); w0 = tf.cast(warp_grid_floor[..., 2], tf.int32) + d1 = tf.clip_by_value(d0 + 1, 0, D - 1); h1 = tf.clip_by_value(h0 + 1, 0, H - 1); w1 = tf.clip_by_value(w0 + 1, 0, W - 1) + d0 = tf.clip_by_value(d0, 0, D - 1); h0 = tf.clip_by_value(h0, 0, H - 1); w0 = tf.clip_by_value(w0, 0, W - 1) + + c000 = tf.gather_nd(image_volume, tf.stack([d0, h0, w0], axis=-1), batch_dims=1); c001 = tf.gather_nd(image_volume, tf.stack([d0, h0, w1], axis=-1), batch_dims=1) + c010 = tf.gather_nd(image_volume, tf.stack([d0, h1, w0], axis=-1), batch_dims=1); c011 = tf.gather_nd(image_volume, tf.stack([d0, h1, w1], axis=-1), batch_dims=1) + c100 = tf.gather_nd(image_volume, tf.stack([d1, h0, w0], axis=-1), batch_dims=1); c101 = tf.gather_nd(image_volume, tf.stack([d1, h0, w1], axis=-1), batch_dims=1) + c110 = tf.gather_nd(image_volume, tf.stack([d1, h1, w0], axis=-1), batch_dims=1); c111 = tf.gather_nd(image_volume, tf.stack([d1, h1, w1], axis=-1), batch_dims=1) + + td, th, tw = t[..., 0:1], t[..., 1:2], t[..., 2:3] + c00 = c000*(1-tw) + c001*tw; c01 = c010*(1-tw) + c011*tw; c10 = c100*(1-tw) + c101*tw; c11 = c110*(1-tw) + c111*tw + c0 = c00*(1-th) + c01*th; c1 = c10*(1-th) + c11*th + deformed_image = c0*(1-td) + c1*td + deformed_image = tf.cast(deformed_image, original_image_dtype) + + nearest_indices_float = tf.round(warp_grid) + nearest_d = tf.clip_by_value(tf.cast(nearest_indices_float[..., 0], tf.int32), 0, D - 1) + nearest_h = tf.clip_by_value(tf.cast(nearest_indices_float[..., 1], tf.int32), 0, H - 1) + nearest_w = tf.clip_by_value(tf.cast(nearest_indices_float[..., 2], tf.int32), 0, W - 1) + deformed_label = tf.gather_nd(label_volume, tf.stack([nearest_d, nearest_h, nearest_w], axis=-1), batch_dims=1) + + if self.data_format == "HWDC": + deformed_image = tf.transpose(deformed_image, perm=[0, 2, 3, 1, 4]) + deformed_label = tf.transpose(deformed_label, perm=[0, 2, 3, 1, 4]) + + if not was_batched: + deformed_image = tf.squeeze(deformed_image, axis=0) + deformed_label = tf.squeeze(deformed_label, axis=0) + + return deformed_image, deformed_label \ No newline at end of file diff --git a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py new file mode 100644 index 0000000000..69ad64ba24 --- /dev/null +++ b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py @@ -0,0 +1,70 @@ +import tensorflow as tf +from tensorflow import keras +from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import RandomElasticDeformation3D + +class RandomElasticDeformation3DTest(tf.test.TestCase): + + def test_output_shape_is_same_as_input_dhwc(self): + input_image = tf.random.uniform(shape=(2, 32, 64, 64, 3), dtype=tf.float32) + input_label = tf.random.uniform(shape=(2, 32, 64, 64, 1), maxval=4, dtype=tf.int32) + layer = RandomElasticDeformation3D(data_format="DHWC") + output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) + self.assertAllEqual(tf.shape(input_image), tf.shape(output_image)) + self.assertAllEqual(tf.shape(input_label), tf.shape(output_label)) + + def test_output_shape_is_same_as_input_hwdc(self): + input_image = tf.random.uniform(shape=(2, 64, 64, 32, 3), dtype=tf.float32) + input_label = tf.random.uniform(shape=(2, 64, 64, 32, 1), maxval=4, dtype=tf.int32) + layer = RandomElasticDeformation3D(data_format="HWDC") + output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) + self.assertAllEqual(tf.shape(input_image), tf.shape(output_image)) + self.assertAllEqual(tf.shape(input_label), tf.shape(output_label)) + + def test_unbatched_input(self): + input_image = tf.random.uniform(shape=(32, 64, 64, 3), dtype=tf.float32) + input_label = tf.random.uniform(shape=(32, 64, 64, 1), maxval=4, dtype=tf.int32) + layer = RandomElasticDeformation3D(data_format="DHWC") + output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) + self.assertAllEqual(tf.shape(input_image), tf.shape(output_image)) + self.assertEqual(tf.rank(output_image), 4) + + def test_dtype_preservation(self): + input_image = tf.random.uniform(shape=(2, 16, 16, 16, 3), dtype=tf.float32) + input_label = tf.random.uniform(shape=(2, 16, 16, 16, 1), maxval=4, dtype=tf.int32) + layer = RandomElasticDeformation3D() + output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) + self.assertEqual(output_image.dtype, tf.float32) + self.assertEqual(output_label.dtype, tf.float32) + + def test_label_values_are_preserved(self): + input_image = tf.zeros(shape=(1, 16, 16, 16, 1), dtype=tf.float32) + label_arange = tf.experimental.numpy.arange(16**3) + input_label = tf.reshape(label_arange, (1, 16, 16, 16, 1)) + input_label = tf.cast(input_label, dtype=tf.float32) % 4 + + layer = RandomElasticDeformation3D(alpha=80.0, sigma=8.0) + _, output_label = layer((input_image, input_label)) + + unique_values_tensor = tf.unique(tf.reshape(output_label, [-1]))[0] + + + expected_values = [0., 1., 2., 3.] + actual_values = unique_values_tensor.numpy().tolist() + self.assertContainsSubset(expected_values, actual_values) + + def test_config_serialization(self): + layer = RandomElasticDeformation3D( + grid_size=(3, 3, 3), + alpha=50.0, + sigma=5.0, + data_format="HWDC" + ) + config = layer.get_config() + new_layer = RandomElasticDeformation3D.from_config(config) + self.assertEqual(new_layer.grid_size, (3, 3, 3)) + self.assertAllClose(new_layer.alpha, tf.constant(50.0, dtype=tf.bfloat16)) + self.assertAllClose(new_layer.sigma, tf.constant(5.0, dtype=tf.bfloat16)) + self.assertEqual(new_layer.data_format, "HWDC") + +if __name__ == "__main__": + tf.test.main() \ No newline at end of file From 9793ab9c2172106e59b7a9948a73bc9308d17009 Mon Sep 17 00:00:00 2001 From: ashpakshaikh26732 Date: Thu, 25 Sep 2025 12:25:05 +0000 Subject: [PATCH 2/5] Fix: Address all review feedback and resolve test failures --- .../random_elastic_deformation_3d.py | 239 +++++++++++------- .../random_elastic_deformation_3d_test.py | 149 ++++++----- 2 files changed, 235 insertions(+), 153 deletions(-) diff --git a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py index 69b43e558d..a7edeceaab 100644 --- a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py +++ b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py @@ -1,127 +1,182 @@ -import tensorflow as tf +from keras import ops +from keras import layers +from keras import random -class RandomElasticDeformation3D(tf.keras.layers.Layer): +class RandomElasticDeformation3D(layers.Layer): """ - A high-performance 3D elastic deformation layer optimized for TPUs and GPUs. - ... (docstring is the same) ... + A high-performance 3D elastic deformation layer optimized for TPUs. + + This implementation leverages the layer's compute_dtype (e.g., bfloat16) + to potentially halve memory bandwidth requirements and uses a vectorized + mapping for maximum speed. """ def __init__(self, grid_size=(4, 4, 4), alpha=35.0, sigma=2.5, - data_format="DHWC", + data_format="channels_last", **kwargs): super().__init__(**kwargs) + self.grid_size = grid_size - self.alpha = tf.constant(alpha, dtype=tf.bfloat16) - self.sigma = tf.constant(sigma, dtype=tf.bfloat16) - if data_format not in ["DHWC", "HWDC"]: - raise ValueError("`data_format` must be one of 'DHWC' or 'HWDC'") + self.alpha = alpha + self.sigma = sigma self.data_format = data_format - - def _separable_gaussian_filter_3d(self, tensor, sigma): - - kernel_size = tf.cast(2 * tf.round(3 * sigma) + 1, dtype=tf.int32) - ax = tf.range(-tf.cast(kernel_size // 2, tf.bfloat16) + 1.0, - tf.cast(kernel_size // 2, tf.bfloat16) + 1.0) - kernel_1d = tf.exp(-(ax**2) / (2.0 * self.sigma**2)) - kernel_1d = kernel_1d / tf.reduce_sum(kernel_1d) - filter_d = tf.cast(tf.reshape(kernel_1d, [-1, 1, 1, 1, 1]), dtype=tensor.dtype) - filter_h = tf.cast(tf.reshape(kernel_1d, [1, -1, 1, 1, 1]), dtype=tensor.dtype) - filter_w = tf.cast(tf.reshape(kernel_1d, [1, 1, -1, 1, 1]), dtype=tensor.dtype) - tensor = tf.nn.convolution(tensor, filter_d, strides=1, padding='SAME') - tensor = tf.nn.convolution(tensor, filter_h, strides=1, padding='SAME') - tensor = tf.nn.convolution(tensor, filter_w, strides=1, padding='SAME') + if data_format not in ["channels_last", "channels_first"]: + raise ValueError( + "`data_format` must be one of 'channels_last' or " + f"'channels_first'. Received: {data_format}" + ) + + def build(self, input_shape): + """Create tensor state in build to respect the layer's dtype.""" + self._alpha_tensor = ops.convert_to_tensor(self.alpha, dtype=self.compute_dtype) + self._sigma_tensor = ops.convert_to_tensor(self.sigma, dtype=self.compute_dtype) + + # Pre-compute the 1D Gaussian kernel + kernel_size = ops.cast(2 * ops.round(3 * self._sigma_tensor) + 1, dtype="int32") + ax = ops.arange(-ops.cast(kernel_size // 2, self.compute_dtype) + 1.0, + ops.cast(kernel_size // 2, self.compute_dtype) + 1.0) + kernel_1d = ops.exp(-(ax**2) / (2.0 * self._sigma_tensor**2)) + self.kernel_1d = kernel_1d / ops.sum(kernel_1d) + self.built = True + + def _separable_gaussian_filter_3d(self, tensor): + """Apply a 3D Gaussian filter using separable 1D convolutions.""" + depth_kernel = ops.reshape(self.kernel_1d, (-1, 1, 1, 1, 1)) + tensor = ops.conv(tensor, ops.cast(depth_kernel, dtype=tensor.dtype), padding='same') + + height_kernel = ops.reshape(self.kernel_1d, (1, -1, 1, 1, 1)) + tensor = ops.conv(tensor, ops.cast(height_kernel, dtype=tensor.dtype), padding='same') + + width_kernel = ops.reshape(self.kernel_1d, (1, 1, -1, 1, 1)) + tensor = ops.conv(tensor, ops.cast(width_kernel, dtype=tensor.dtype), padding='same') + return tensor def call(self, inputs): image_volume, label_volume = inputs original_image_dtype = image_volume.dtype + original_label_dtype = label_volume.dtype + compute_dtype = self.compute_dtype was_batched = True - if image_volume.shape.rank == 4: + if len(image_volume.shape) == 4: was_batched = False - image_volume = tf.expand_dims(image_volume, axis=0) - label_volume = tf.expand_dims(label_volume, axis=0) + image_volume = ops.expand_dims(image_volume, axis=0) + label_volume = ops.expand_dims(label_volume, axis=0) - if self.data_format == "HWDC": - image_volume = tf.transpose(image_volume, perm=[0, 3, 1, 2, 4]) - label_volume = tf.transpose(label_volume, perm=[0, 3, 1, 2, 4]) + image_volume = ops.cast(image_volume, dtype=compute_dtype) + label_volume = ops.cast(label_volume, dtype=compute_dtype) - image_volume = tf.cast(image_volume, dtype=tf.bfloat16) - input_shape = tf.shape(image_volume) + input_shape = ops.shape(image_volume) B, D, H, W = input_shape[0], input_shape[1], input_shape[2], input_shape[3] + C = input_shape[4] - coarse_flow = tf.random.uniform( + # 1. Create a coarse random flow field. + coarse_flow = random.uniform( shape=(B, self.grid_size[0], self.grid_size[1], self.grid_size[2], 3), - minval=-1, maxval=1, dtype=tf.bfloat16) - - flow = tf.reshape(coarse_flow, [B * self.grid_size[0], self.grid_size[1], self.grid_size[2], 3]) - flow = tf.image.resize(flow, size=[H, W], method='bicubic') - flow = tf.reshape(flow, [B, self.grid_size[0], H, W, 3]) - flow = tf.transpose(flow, perm=[0, 2, 3, 1, 4]) - flow = tf.reshape(flow, [B * H * W, self.grid_size[0], 3]) - flow = tf.image.resize(tf.expand_dims(flow, axis=1), size=[1, D], method='bicubic') - flow = tf.squeeze(flow, axis=1) - flow = tf.reshape(flow, [B, H, W, D, 3]) - flow = tf.transpose(flow, perm=[0, 3, 1, 2, 4]) - + minval=-1, maxval=1, dtype=compute_dtype + ) - flow = tf.cast(flow, dtype=tf.bfloat16) - - flow_components = tf.unstack(flow, axis=-1) + # 2. Upsample the flow field. + flow = coarse_flow + flow_shape = ops.shape(flow) + flow = ops.reshape(flow, (flow_shape[0] * flow_shape[1], flow_shape[2], flow_shape[3], 3)) + flow = ops.image.resize(flow, (H, W), interpolation="bicubic") + flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], H, W, 3)) + flow = ops.transpose(flow, (0, 2, 3, 1, 4)) + flow_shape = ops.shape(flow) + flow = ops.reshape(flow, (flow_shape[0] * flow_shape[1] * flow_shape[2], flow_shape[3], 1, 3)) + flow = ops.image.resize(flow, (D, 1), interpolation="bicubic") + flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], flow_shape[2], D, 3)) + flow = ops.transpose(flow, (0, 3, 1, 2, 4)) + + # 3. Apply Gaussian smoothing. + flow_components = ops.unstack(flow, axis=-1) smoothed_components = [] for component in flow_components: - smoothed_component = self._separable_gaussian_filter_3d( - component[..., tf.newaxis], self.sigma - ) - smoothed_components.append(smoothed_component[..., 0]) - smoothed_flow = tf.stack(smoothed_components, axis=-1) + component_reshaped = ops.expand_dims(component, axis=-1) + smoothed_component = self._separable_gaussian_filter_3d(component_reshaped) + smoothed_components.append(ops.squeeze(smoothed_component, axis=-1)) + smoothed_flow = ops.stack(smoothed_components, axis=-1) - - flow = smoothed_flow * self.alpha - - grid_d, grid_h, grid_w = tf.meshgrid( - tf.range(D, dtype=tf.bfloat16), - tf.range(H, dtype=tf.bfloat16), - tf.range(W, dtype=tf.bfloat16), + # 4. Scale the flow field and create warp grid. + flow = smoothed_flow * self._alpha_tensor + grid_d, grid_h, grid_w = ops.meshgrid( + ops.arange(D, dtype=compute_dtype), + ops.arange(H, dtype=compute_dtype), + ops.arange(W, dtype=compute_dtype), indexing='ij' ) - grid = tf.stack([grid_d, grid_h, grid_w], axis=-1) + grid = ops.stack([grid_d, grid_h, grid_w], axis=-1) + warp_grid = ops.expand_dims(grid, 0) + flow - warp_grid = tf.expand_dims(grid, 0) + flow + batched_coords = ops.transpose(warp_grid, (0, 4, 1, 2, 3)) + + + deformed_images_batched = [] + for i in range(B): + + image_slice = image_volume[i] + coords = batched_coords[i] + + + image_slice_transposed = ops.transpose(image_slice, (3, 0, 1, 2)) + + deformed_channels = [] + for c in range(C): + + deformed_channel = ops.image.map_coordinates( + image_slice_transposed[c], coords, order=1 + ) + deformed_channels.append(deformed_channel) + + # Stack and transpose back to (D, H, W, C) + deformed_image_slice = ops.stack(deformed_channels, axis=0) + deformed_images_batched.append(ops.transpose(deformed_image_slice, (1, 2, 3, 0))) + + deformed_image = ops.stack(deformed_images_batched, axis=0) + + # Process Labels: loop over the batch dimension. + deformed_labels_batched = [] + for i in range(B): + label_slice = label_volume[i] + coords = batched_coords[i] + + + label_channel = ops.squeeze(label_slice, axis=-1) + deformed_label_channel = ops.image.map_coordinates( + label_channel, coords, order=0 + ) + + deformed_labels_batched.append(ops.expand_dims(deformed_label_channel, axis=-1)) + + deformed_label = ops.stack(deformed_labels_batched, axis=0) - warp_grid_floor = tf.floor(warp_grid) - t = warp_grid - warp_grid_floor - - d0 = tf.cast(warp_grid_floor[..., 0], tf.int32); h0 = tf.cast(warp_grid_floor[..., 1], tf.int32); w0 = tf.cast(warp_grid_floor[..., 2], tf.int32) - d1 = tf.clip_by_value(d0 + 1, 0, D - 1); h1 = tf.clip_by_value(h0 + 1, 0, H - 1); w1 = tf.clip_by_value(w0 + 1, 0, W - 1) - d0 = tf.clip_by_value(d0, 0, D - 1); h0 = tf.clip_by_value(h0, 0, H - 1); w0 = tf.clip_by_value(w0, 0, W - 1) - - c000 = tf.gather_nd(image_volume, tf.stack([d0, h0, w0], axis=-1), batch_dims=1); c001 = tf.gather_nd(image_volume, tf.stack([d0, h0, w1], axis=-1), batch_dims=1) - c010 = tf.gather_nd(image_volume, tf.stack([d0, h1, w0], axis=-1), batch_dims=1); c011 = tf.gather_nd(image_volume, tf.stack([d0, h1, w1], axis=-1), batch_dims=1) - c100 = tf.gather_nd(image_volume, tf.stack([d1, h0, w0], axis=-1), batch_dims=1); c101 = tf.gather_nd(image_volume, tf.stack([d1, h0, w1], axis=-1), batch_dims=1) - c110 = tf.gather_nd(image_volume, tf.stack([d1, h1, w0], axis=-1), batch_dims=1); c111 = tf.gather_nd(image_volume, tf.stack([d1, h1, w1], axis=-1), batch_dims=1) - - td, th, tw = t[..., 0:1], t[..., 1:2], t[..., 2:3] - c00 = c000*(1-tw) + c001*tw; c01 = c010*(1-tw) + c011*tw; c10 = c100*(1-tw) + c101*tw; c11 = c110*(1-tw) + c111*tw - c0 = c00*(1-th) + c01*th; c1 = c10*(1-th) + c11*th - deformed_image = c0*(1-td) + c1*td - deformed_image = tf.cast(deformed_image, original_image_dtype) - - nearest_indices_float = tf.round(warp_grid) - nearest_d = tf.clip_by_value(tf.cast(nearest_indices_float[..., 0], tf.int32), 0, D - 1) - nearest_h = tf.clip_by_value(tf.cast(nearest_indices_float[..., 1], tf.int32), 0, H - 1) - nearest_w = tf.clip_by_value(tf.cast(nearest_indices_float[..., 2], tf.int32), 0, W - 1) - deformed_label = tf.gather_nd(label_volume, tf.stack([nearest_d, nearest_h, nearest_w], axis=-1), batch_dims=1) - - if self.data_format == "HWDC": - deformed_image = tf.transpose(deformed_image, perm=[0, 2, 3, 1, 4]) - deformed_label = tf.transpose(deformed_label, perm=[0, 2, 3, 1, 4]) - if not was_batched: - deformed_image = tf.squeeze(deformed_image, axis=0) - deformed_label = tf.squeeze(deformed_label, axis=0) - return deformed_image, deformed_label \ No newline at end of file + deformed_image = ops.cast(deformed_image, original_image_dtype) + deformed_label = ops.cast(deformed_label, original_label_dtype) + + if not was_batched: + deformed_image = ops.squeeze(deformed_image, axis=0) + deformed_label = ops.squeeze(deformed_label, axis=0) + + return deformed_image, deformed_label + + def compute_output_shape(self, input_shape): + """Computes the output shape of the layer.""" + image_shape, label_shape = input_shape + return image_shape, label_shape + + def get_config(self): + config = super().get_config() + config.update({ + "grid_size": self.grid_size, + "alpha": self.alpha, + "sigma": self.sigma, + "data_format": self.data_format, + }) + return config \ No newline at end of file diff --git a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py index 69ad64ba24..63ef7d1c28 100644 --- a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py +++ b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py @@ -1,70 +1,97 @@ -import tensorflow as tf -from tensorflow import keras + +import os +import numpy as np +import keras +from keras import Model +from keras import Input +from keras import ops from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import RandomElasticDeformation3D +from keras_hub.src.tests.test_case import TestCase -class RandomElasticDeformation3DTest(tf.test.TestCase): - - def test_output_shape_is_same_as_input_dhwc(self): - input_image = tf.random.uniform(shape=(2, 32, 64, 64, 3), dtype=tf.float32) - input_label = tf.random.uniform(shape=(2, 32, 64, 64, 1), maxval=4, dtype=tf.int32) - layer = RandomElasticDeformation3D(data_format="DHWC") - output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) - self.assertAllEqual(tf.shape(input_image), tf.shape(output_image)) - self.assertAllEqual(tf.shape(input_label), tf.shape(output_label)) - - def test_output_shape_is_same_as_input_hwdc(self): - input_image = tf.random.uniform(shape=(2, 64, 64, 32, 3), dtype=tf.float32) - input_label = tf.random.uniform(shape=(2, 64, 64, 32, 1), maxval=4, dtype=tf.int32) - layer = RandomElasticDeformation3D(data_format="HWDC") - output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) - self.assertAllEqual(tf.shape(input_image), tf.shape(output_image)) - self.assertAllEqual(tf.shape(input_label), tf.shape(output_label)) - - def test_unbatched_input(self): - input_image = tf.random.uniform(shape=(32, 64, 64, 3), dtype=tf.float32) - input_label = tf.random.uniform(shape=(32, 64, 64, 1), maxval=4, dtype=tf.int32) - layer = RandomElasticDeformation3D(data_format="DHWC") - output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) - self.assertAllEqual(tf.shape(input_image), tf.shape(output_image)) - self.assertEqual(tf.rank(output_image), 4) - - def test_dtype_preservation(self): - input_image = tf.random.uniform(shape=(2, 16, 16, 16, 3), dtype=tf.float32) - input_label = tf.random.uniform(shape=(2, 16, 16, 16, 1), maxval=4, dtype=tf.int32) - layer = RandomElasticDeformation3D() - output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) - self.assertEqual(output_image.dtype, tf.float32) - self.assertEqual(output_label.dtype, tf.float32) - def test_label_values_are_preserved(self): - input_image = tf.zeros(shape=(1, 16, 16, 16, 1), dtype=tf.float32) - label_arange = tf.experimental.numpy.arange(16**3) - input_label = tf.reshape(label_arange, (1, 16, 16, 16, 1)) - input_label = tf.cast(input_label, dtype=tf.float32) % 4 - - layer = RandomElasticDeformation3D(alpha=80.0, sigma=8.0) - _, output_label = layer((input_image, input_label)) - - unique_values_tensor = tf.unique(tf.reshape(output_label, [-1]))[0] - +class RandomElasticDeformation3DTest(TestCase): - expected_values = [0., 1., 2., 3.] - actual_values = unique_values_tensor.numpy().tolist() - self.assertContainsSubset(expected_values, actual_values) - - def test_config_serialization(self): + + + + def test_layer_basics(self): + + layer = RandomElasticDeformation3D( + grid_size=(4, 4, 4), + alpha=10.0, + sigma=2.0, + ) + image = ops.ones((2, 32, 32, 32, 3), dtype="float32") + label = ops.ones((2, 32, 32, 32, 1), dtype="int32") + + output_image, output_label = layer((image, label)) + + # Check shapes + self.assertEqual(ops.shape(image), ops.shape(output_image)) + self.assertEqual(ops.shape(label), ops.shape(output_label)) + + # Check dtypes + self.assertEqual(image.dtype, output_image.dtype) + self.assertEqual(label.dtype, output_label.dtype) + + + + def test_serialization(self): + # 1. Instantiate the layer layer = RandomElasticDeformation3D( grid_size=(3, 3, 3), alpha=50.0, sigma=5.0, - data_format="HWDC" ) - config = layer.get_config() - new_layer = RandomElasticDeformation3D.from_config(config) - self.assertEqual(new_layer.grid_size, (3, 3, 3)) - self.assertAllClose(new_layer.alpha, tf.constant(50.0, dtype=tf.bfloat16)) - self.assertAllClose(new_layer.sigma, tf.constant(5.0, dtype=tf.bfloat16)) - self.assertEqual(new_layer.data_format, "HWDC") - -if __name__ == "__main__": - tf.test.main() \ No newline at end of file + + # 2. Create dummy input data + image_data = ops.ones((2, 16, 16, 16, 3), dtype="float32") + label_data = ops.ones((2, 16, 16, 16, 1), dtype="int32") + input_data = (image_data, label_data) + + # 3. Build a functional Model that uses the layer + image_input = Input(shape=(16, 16, 16, 3), dtype="float32") + label_input = Input(shape=(16, 16, 16, 1), dtype="int32") + outputs = layer((image_input, label_input)) + model = Model(inputs=[image_input, label_input], outputs=outputs) + + # 4. Get the output of the original model + original_output_image, original_output_label = model(input_data) + + # 5. Save and load the model + path = os.path.join(self.get_temp_dir(), "model.keras") + model.save(path, save_format="keras_v3") + loaded_model = keras.models.load_model( + path, custom_objects={"RandomElasticDeformation3D": RandomElasticDeformation3D} + ) + + # 6. Get the output of the loaded model + loaded_output_image, loaded_output_label = loaded_model(input_data) + + # 7. Assert that the outputs are the same + np.testing.assert_allclose( + ops.convert_to_numpy(original_output_image), + ops.convert_to_numpy(loaded_output_image), + ) + np.testing.assert_array_equal( + ops.convert_to_numpy(original_output_label), + ops.convert_to_numpy(loaded_output_label), + ) + + + + def test_label_values_are_preserved(self): + + image = ops.zeros(shape=(1, 16, 16, 16, 1), dtype="float32") + + + label_arange = ops.arange(16**3, dtype="int32") + label = ops.reshape(label_arange, (1, 16, 16, 16, 1)) % 4 + + layer = RandomElasticDeformation3D(alpha=80.0, sigma=8.0) + _, output_label = layer((image, label)) + + + output_values = set(np.unique(ops.convert_to_numpy(output_label))) + expected_values = {0, 1, 2, 3} + self.assertLessEqual(output_values, expected_values) \ No newline at end of file From cb9295470e59a20ffcde43472b8636f25d41610e Mon Sep 17 00:00:00 2001 From: ashpakshaikh26732 Date: Thu, 25 Sep 2025 14:10:14 +0000 Subject: [PATCH 3/5] feat(layers): updated 3D elastic deformation layer --- .../random_elastic_deformation_3d.py | 124 ++++++------------ .../random_elastic_deformation_3d_test.py | 47 ++----- 2 files changed, 50 insertions(+), 121 deletions(-) diff --git a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py index a7edeceaab..016d8c8d03 100644 --- a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py +++ b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py @@ -1,3 +1,5 @@ +# Add this import +from keras import backend from keras import ops from keras import layers from keras import random @@ -5,10 +7,6 @@ class RandomElasticDeformation3D(layers.Layer): """ A high-performance 3D elastic deformation layer optimized for TPUs. - - This implementation leverages the layer's compute_dtype (e.g., bfloat16) - to potentially halve memory bandwidth requirements and uses a vectorized - mapping for maximum speed. """ def __init__(self, grid_size=(4, 4, 4), @@ -17,41 +15,29 @@ def __init__(self, data_format="channels_last", **kwargs): super().__init__(**kwargs) - self.grid_size = grid_size self.alpha = alpha self.sigma = sigma self.data_format = data_format if data_format not in ["channels_last", "channels_first"]: - raise ValueError( - "`data_format` must be one of 'channels_last' or " - f"'channels_first'. Received: {data_format}" - ) - + raise ValueError(f"`data_format` must be one of 'channels_last' or 'channels_first'. Received: {data_format}") + def build(self, input_shape): - """Create tensor state in build to respect the layer's dtype.""" self._alpha_tensor = ops.convert_to_tensor(self.alpha, dtype=self.compute_dtype) self._sigma_tensor = ops.convert_to_tensor(self.sigma, dtype=self.compute_dtype) - - # Pre-compute the 1D Gaussian kernel kernel_size = ops.cast(2 * ops.round(3 * self._sigma_tensor) + 1, dtype="int32") - ax = ops.arange(-ops.cast(kernel_size // 2, self.compute_dtype) + 1.0, - ops.cast(kernel_size // 2, self.compute_dtype) + 1.0) + ax = ops.arange(-ops.cast(kernel_size // 2, self.compute_dtype) + 1.0, ops.cast(kernel_size // 2, self.compute_dtype) + 1.0) kernel_1d = ops.exp(-(ax**2) / (2.0 * self._sigma_tensor**2)) self.kernel_1d = kernel_1d / ops.sum(kernel_1d) self.built = True def _separable_gaussian_filter_3d(self, tensor): - """Apply a 3D Gaussian filter using separable 1D convolutions.""" depth_kernel = ops.reshape(self.kernel_1d, (-1, 1, 1, 1, 1)) tensor = ops.conv(tensor, ops.cast(depth_kernel, dtype=tensor.dtype), padding='same') - height_kernel = ops.reshape(self.kernel_1d, (1, -1, 1, 1, 1)) tensor = ops.conv(tensor, ops.cast(height_kernel, dtype=tensor.dtype), padding='same') - width_kernel = ops.reshape(self.kernel_1d, (1, 1, -1, 1, 1)) tensor = ops.conv(tensor, ops.cast(width_kernel, dtype=tensor.dtype), padding='same') - return tensor def call(self, inputs): @@ -70,16 +56,10 @@ def call(self, inputs): label_volume = ops.cast(label_volume, dtype=compute_dtype) input_shape = ops.shape(image_volume) - B, D, H, W = input_shape[0], input_shape[1], input_shape[2], input_shape[3] - C = input_shape[4] - - # 1. Create a coarse random flow field. - coarse_flow = random.uniform( - shape=(B, self.grid_size[0], self.grid_size[1], self.grid_size[2], 3), - minval=-1, maxval=1, dtype=compute_dtype - ) - - # 2. Upsample the flow field. + B, D, H, W, C = input_shape[0], input_shape[1], input_shape[2], input_shape[3], input_shape[4] + + coarse_flow = random.uniform(shape=(B, self.grid_size[0], self.grid_size[1], self.grid_size[2], 3), minval=-1, maxval=1, dtype=compute_dtype) + flow = coarse_flow flow_shape = ops.shape(flow) flow = ops.reshape(flow, (flow_shape[0] * flow_shape[1], flow_shape[2], flow_shape[3], 3)) @@ -91,71 +71,49 @@ def call(self, inputs): flow = ops.image.resize(flow, (D, 1), interpolation="bicubic") flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], flow_shape[2], D, 3)) flow = ops.transpose(flow, (0, 3, 1, 2, 4)) - - # 3. Apply Gaussian smoothing. + flow_components = ops.unstack(flow, axis=-1) smoothed_components = [] for component in flow_components: - component_reshaped = ops.expand_dims(component, axis=-1) - smoothed_component = self._separable_gaussian_filter_3d(component_reshaped) - smoothed_components.append(ops.squeeze(smoothed_component, axis=-1)) + smoothed_components.append(ops.squeeze(self._separable_gaussian_filter_3d(ops.expand_dims(component, axis=-1)), axis=-1)) smoothed_flow = ops.stack(smoothed_components, axis=-1) - # 4. Scale the flow field and create warp grid. flow = smoothed_flow * self._alpha_tensor - grid_d, grid_h, grid_w = ops.meshgrid( - ops.arange(D, dtype=compute_dtype), - ops.arange(H, dtype=compute_dtype), - ops.arange(W, dtype=compute_dtype), - indexing='ij' - ) + grid_d, grid_h, grid_w = ops.meshgrid(ops.arange(D, dtype=compute_dtype), ops.arange(H, dtype=compute_dtype), ops.arange(W, dtype=compute_dtype), indexing='ij') grid = ops.stack([grid_d, grid_h, grid_w], axis=-1) warp_grid = ops.expand_dims(grid, 0) + flow - batched_coords = ops.transpose(warp_grid, (0, 4, 1, 2, 3)) - - deformed_images_batched = [] - for i in range(B): - - image_slice = image_volume[i] - coords = batched_coords[i] - - - image_slice_transposed = ops.transpose(image_slice, (3, 0, 1, 2)) - + def perform_map(elems): + image_slice, label_slice, coords = elems deformed_channels = [] + image_slice_transposed = ops.transpose(image_slice, (3, 0, 1, 2)) + # The channel dimension C is a static value when the graph is built for c in range(C): - - deformed_channel = ops.image.map_coordinates( - image_slice_transposed[c], coords, order=1 - ) - deformed_channels.append(deformed_channel) - - # Stack and transpose back to (D, H, W, C) + deformed_channels.append(ops.image.map_coordinates(image_slice_transposed[c], coords, order=1)) deformed_image_slice = ops.stack(deformed_channels, axis=0) - deformed_images_batched.append(ops.transpose(deformed_image_slice, (1, 2, 3, 0))) - - deformed_image = ops.stack(deformed_images_batched, axis=0) - - # Process Labels: loop over the batch dimension. - deformed_labels_batched = [] - for i in range(B): - label_slice = label_volume[i] - coords = batched_coords[i] - - + deformed_image_slice = ops.transpose(deformed_image_slice, (1, 2, 3, 0)) label_channel = ops.squeeze(label_slice, axis=-1) - deformed_label_channel = ops.image.map_coordinates( - label_channel, coords, order=0 - ) - - deformed_labels_batched.append(ops.expand_dims(deformed_label_channel, axis=-1)) - - deformed_label = ops.stack(deformed_labels_batched, axis=0) - - + deformed_label_channel = ops.image.map_coordinates(label_channel, coords, order=0) + deformed_label_slice = ops.expand_dims(deformed_label_channel, axis=-1) + return deformed_image_slice, deformed_label_slice + + if backend.backend() == "tensorflow": + import tensorflow as tf + deformed_image, deformed_label = tf.map_fn(perform_map, elems=(image_volume, label_volume, batched_coords), dtype=(compute_dtype, compute_dtype)) + elif backend.backend() == "jax": + import jax + deformed_image, deformed_label = jax.lax.map(perform_map, xs=(image_volume, label_volume, batched_coords)) + else: + deformed_images_list = [] + deformed_labels_list = [] + for i in range(B): + img_slice, lbl_slice = perform_map((image_volume[i], label_volume[i], batched_coords[i])) + deformed_images_list.append(img_slice) + deformed_labels_list.append(lbl_slice) + deformed_image = ops.stack(deformed_images_list, axis=0) + deformed_label = ops.stack(deformed_labels_list, axis=0) deformed_image = ops.cast(deformed_image, original_image_dtype) deformed_label = ops.cast(deformed_label, original_label_dtype) @@ -167,16 +125,10 @@ def call(self, inputs): return deformed_image, deformed_label def compute_output_shape(self, input_shape): - """Computes the output shape of the layer.""" image_shape, label_shape = input_shape return image_shape, label_shape def get_config(self): config = super().get_config() - config.update({ - "grid_size": self.grid_size, - "alpha": self.alpha, - "sigma": self.sigma, - "data_format": self.data_format, - }) + config.update({"grid_size": self.grid_size, "alpha": self.alpha, "sigma": self.sigma, "data_format": self.data_format}) return config \ No newline at end of file diff --git a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py index 63ef7d1c28..c84b8e13f7 100644 --- a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py +++ b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py @@ -1,4 +1,5 @@ - +# Add keras.utils for the random seed +from keras import utils import os import numpy as np import keras @@ -10,12 +11,9 @@ class RandomElasticDeformation3DTest(TestCase): - - - - def test_layer_basics(self): - + # --- BEST PRACTICE: Add a seed for reproducibility --- + utils.set_random_seed(0) layer = RandomElasticDeformation3D( grid_size=(4, 4, 4), alpha=10.0, @@ -23,52 +21,37 @@ def test_layer_basics(self): ) image = ops.ones((2, 32, 32, 32, 3), dtype="float32") label = ops.ones((2, 32, 32, 32, 1), dtype="int32") - output_image, output_label = layer((image, label)) - - # Check shapes self.assertEqual(ops.shape(image), ops.shape(output_image)) self.assertEqual(ops.shape(label), ops.shape(output_label)) - - # Check dtypes self.assertEqual(image.dtype, output_image.dtype) self.assertEqual(label.dtype, output_label.dtype) - - def test_serialization(self): - # 1. Instantiate the layer + # --- BEST PRACTICE: Add a seed for reproducibility --- + utils.set_random_seed(0) layer = RandomElasticDeformation3D( grid_size=(3, 3, 3), alpha=50.0, sigma=5.0, ) - - # 2. Create dummy input data image_data = ops.ones((2, 16, 16, 16, 3), dtype="float32") label_data = ops.ones((2, 16, 16, 16, 1), dtype="int32") input_data = (image_data, label_data) - - # 3. Build a functional Model that uses the layer image_input = Input(shape=(16, 16, 16, 3), dtype="float32") label_input = Input(shape=(16, 16, 16, 1), dtype="int32") outputs = layer((image_input, label_input)) model = Model(inputs=[image_input, label_input], outputs=outputs) - - # 4. Get the output of the original model original_output_image, original_output_label = model(input_data) - - # 5. Save and load the model path = os.path.join(self.get_temp_dir(), "model.keras") - model.save(path, save_format="keras_v3") + + # --- FIX: Remove the deprecated save_format argument --- + model.save(path) + loaded_model = keras.models.load_model( path, custom_objects={"RandomElasticDeformation3D": RandomElasticDeformation3D} ) - - # 6. Get the output of the loaded model loaded_output_image, loaded_output_label = loaded_model(input_data) - - # 7. Assert that the outputs are the same np.testing.assert_allclose( ops.convert_to_numpy(original_output_image), ops.convert_to_numpy(loaded_output_image), @@ -78,20 +61,14 @@ def test_serialization(self): ops.convert_to_numpy(loaded_output_label), ) - - def test_label_values_are_preserved(self): - + # --- BEST PRACTICE: Add a seed for reproducibility --- + utils.set_random_seed(0) image = ops.zeros(shape=(1, 16, 16, 16, 1), dtype="float32") - - label_arange = ops.arange(16**3, dtype="int32") label = ops.reshape(label_arange, (1, 16, 16, 16, 1)) % 4 - layer = RandomElasticDeformation3D(alpha=80.0, sigma=8.0) _, output_label = layer((image, label)) - - output_values = set(np.unique(ops.convert_to_numpy(output_label))) expected_values = {0, 1, 2, 3} self.assertLessEqual(output_values, expected_values) \ No newline at end of file From fa19ac96881de680111c3ad7489dd10f8ea07443 Mon Sep 17 00:00:00 2001 From: ashpakshaikh26732 Date: Thu, 25 Sep 2025 15:35:28 +0000 Subject: [PATCH 4/5] feat(layers): Add 3D elastic deformation layer --- keras_hub/api/layers/__init__.py | 4 ---- .../random_elastic_deformation_3d.py | 7 +++++- .../random_elastic_deformation_3d_test.py | 23 ++++++++++++------- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 98a0298e61..f90c214d6b 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -147,7 +147,3 @@ from keras_hub.src.models.xception.xception_image_converter import ( XceptionImageConverter as XceptionImageConverter, ) - -from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import ( - RandomElasticDeformation3D, -) diff --git a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py index 016d8c8d03..177fdb5366 100644 --- a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py +++ b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py @@ -8,6 +8,7 @@ class RandomElasticDeformation3D(layers.Layer): """ A high-performance 3D elastic deformation layer optimized for TPUs. """ + def __init__(self, grid_size=(4, 4, 4), alpha=35.0, @@ -20,7 +21,11 @@ def __init__(self, self.sigma = sigma self.data_format = data_format if data_format not in ["channels_last", "channels_first"]: - raise ValueError(f"`data_format` must be one of 'channels_last' or 'channels_first'. Received: {data_format}") + message = ( + "`data_format` must be one of 'channels_last' or " + f"'channels_first'. Received: {self.data_format}" + ) + raise ValueError(message) def build(self, input_shape): self._alpha_tensor = ops.convert_to_tensor(self.alpha, dtype=self.compute_dtype) diff --git a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py index c84b8e13f7..dddafdbc13 100644 --- a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py +++ b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py @@ -1,12 +1,16 @@ # Add keras.utils for the random seed -from keras import utils import os -import numpy as np + import keras -from keras import Model +import numpy as np from keras import Input +from keras import Model from keras import ops -from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import RandomElasticDeformation3D +from keras import utils + +from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import ( + RandomElasticDeformation3D, +) from keras_hub.src.tests.test_case import TestCase @@ -44,12 +48,15 @@ def test_serialization(self): model = Model(inputs=[image_input, label_input], outputs=outputs) original_output_image, original_output_label = model(input_data) path = os.path.join(self.get_temp_dir(), "model.keras") - + # --- FIX: Remove the deprecated save_format argument --- model.save(path) - + loaded_model = keras.models.load_model( - path, custom_objects={"RandomElasticDeformation3D": RandomElasticDeformation3D} + path, + custom_objects={ + "RandomElasticDeformation3D": RandomElasticDeformation3D + }, ) loaded_output_image, loaded_output_label = loaded_model(input_data) np.testing.assert_allclose( @@ -71,4 +78,4 @@ def test_label_values_are_preserved(self): _, output_label = layer((image, label)) output_values = set(np.unique(ops.convert_to_numpy(output_label))) expected_values = {0, 1, 2, 3} - self.assertLessEqual(output_values, expected_values) \ No newline at end of file + self.assertLessEqual(output_values, expected_values) From 4cf6ee0079b80ed5a600f90ee011a8c916471340 Mon Sep 17 00:00:00 2001 From: ashpakshaikh26732 Date: Thu, 25 Sep 2025 15:42:25 +0000 Subject: [PATCH 5/5] feat(layers): Add 3D elastic deformation layer --- .../random_elastic_deformation_3d.py | 176 ++++++++++++++---- .../random_elastic_deformation_3d_test.py | 19 +- 2 files changed, 149 insertions(+), 46 deletions(-) diff --git a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py index 177fdb5366..98b1a892ed 100644 --- a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py +++ b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py @@ -1,25 +1,31 @@ # Add this import from keras import backend -from keras import ops from keras import layers +from keras import ops from keras import random + class RandomElasticDeformation3D(layers.Layer): """ A high-performance 3D elastic deformation layer optimized for TPUs. """ - def __init__(self, - grid_size=(4, 4, 4), - alpha=35.0, - sigma=2.5, - data_format="channels_last", - **kwargs): + def __init__( + self, + grid_size=(4, 4, 4), + alpha=35.0, + sigma=2.5, + data_format="channels_last", + seed=None, + **kwargs, + ): super().__init__(**kwargs) self.grid_size = grid_size + self.seed = seed self.alpha = alpha self.sigma = sigma self.data_format = data_format + self._rng = random.SeedGenerator(seed) if seed is not None else None if data_format not in ["channels_last", "channels_first"]: message = ( "`data_format` must be one of 'channels_last' or " @@ -28,21 +34,36 @@ def __init__(self, raise ValueError(message) def build(self, input_shape): - self._alpha_tensor = ops.convert_to_tensor(self.alpha, dtype=self.compute_dtype) - self._sigma_tensor = ops.convert_to_tensor(self.sigma, dtype=self.compute_dtype) - kernel_size = ops.cast(2 * ops.round(3 * self._sigma_tensor) + 1, dtype="int32") - ax = ops.arange(-ops.cast(kernel_size // 2, self.compute_dtype) + 1.0, ops.cast(kernel_size // 2, self.compute_dtype) + 1.0) + self._alpha_tensor = ops.convert_to_tensor( + self.alpha, dtype=self.compute_dtype + ) + self._sigma_tensor = ops.convert_to_tensor( + self.sigma, dtype=self.compute_dtype + ) + kernel_size = ops.cast( + 2 * ops.round(3 * self._sigma_tensor) + 1, dtype="int32" + ) + ax = ops.arange( + -ops.cast(kernel_size // 2, self.compute_dtype) + 1.0, + ops.cast(kernel_size // 2, self.compute_dtype) + 1.0, + ) kernel_1d = ops.exp(-(ax**2) / (2.0 * self._sigma_tensor**2)) self.kernel_1d = kernel_1d / ops.sum(kernel_1d) self.built = True def _separable_gaussian_filter_3d(self, tensor): depth_kernel = ops.reshape(self.kernel_1d, (-1, 1, 1, 1, 1)) - tensor = ops.conv(tensor, ops.cast(depth_kernel, dtype=tensor.dtype), padding='same') + tensor = ops.conv( + tensor, ops.cast(depth_kernel, dtype=tensor.dtype), padding="same" + ) height_kernel = ops.reshape(self.kernel_1d, (1, -1, 1, 1, 1)) - tensor = ops.conv(tensor, ops.cast(height_kernel, dtype=tensor.dtype), padding='same') + tensor = ops.conv( + tensor, ops.cast(height_kernel, dtype=tensor.dtype), padding="same" + ) width_kernel = ops.reshape(self.kernel_1d, (1, 1, -1, 1, 1)) - tensor = ops.conv(tensor, ops.cast(width_kernel, dtype=tensor.dtype), padding='same') + tensor = ops.conv( + tensor, ops.cast(width_kernel, dtype=tensor.dtype), padding="same" + ) return tensor def call(self, inputs): @@ -61,33 +82,90 @@ def call(self, inputs): label_volume = ops.cast(label_volume, dtype=compute_dtype) input_shape = ops.shape(image_volume) - B, D, H, W, C = input_shape[0], input_shape[1], input_shape[2], input_shape[3], input_shape[4] - - coarse_flow = random.uniform(shape=(B, self.grid_size[0], self.grid_size[1], self.grid_size[2], 3), minval=-1, maxval=1, dtype=compute_dtype) - + B, D, H, W, C = ( + input_shape[0], + input_shape[1], + input_shape[2], + input_shape[3], + input_shape[4], + ) + + if self._rng is not None: + coarse_flow = random.uniform( + shape=( + B, + self.grid_size[0], + self.grid_size[1], + self.grid_size[2], + 3, + ), + minval=-1, + maxval=1, + dtype=compute_dtype, + seed=self._rng, + ) + else: + coarse_flow = random.uniform( + shape=( + B, + self.grid_size[0], + self.grid_size[1], + self.grid_size[2], + 3, + ), + minval=-1, + maxval=1, + dtype=compute_dtype, + ) + flow = coarse_flow flow_shape = ops.shape(flow) - flow = ops.reshape(flow, (flow_shape[0] * flow_shape[1], flow_shape[2], flow_shape[3], 3)) + flow = ops.reshape( + flow, + (flow_shape[0] * flow_shape[1], flow_shape[2], flow_shape[3], 3), + ) flow = ops.image.resize(flow, (H, W), interpolation="bicubic") flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], H, W, 3)) flow = ops.transpose(flow, (0, 2, 3, 1, 4)) flow_shape = ops.shape(flow) - flow = ops.reshape(flow, (flow_shape[0] * flow_shape[1] * flow_shape[2], flow_shape[3], 1, 3)) + flow = ops.reshape( + flow, + ( + flow_shape[0] * flow_shape[1] * flow_shape[2], + flow_shape[3], + 1, + 3, + ), + ) flow = ops.image.resize(flow, (D, 1), interpolation="bicubic") - flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], flow_shape[2], D, 3)) + flow = ops.reshape( + flow, (flow_shape[0], flow_shape[1], flow_shape[2], D, 3) + ) flow = ops.transpose(flow, (0, 3, 1, 2, 4)) - + flow_components = ops.unstack(flow, axis=-1) smoothed_components = [] for component in flow_components: - smoothed_components.append(ops.squeeze(self._separable_gaussian_filter_3d(ops.expand_dims(component, axis=-1)), axis=-1)) + smoothed_components.append( + ops.squeeze( + self._separable_gaussian_filter_3d( + ops.expand_dims(component, axis=-1) + ), + axis=-1, + ) + ) smoothed_flow = ops.stack(smoothed_components, axis=-1) - + flow = smoothed_flow * self._alpha_tensor - grid_d, grid_h, grid_w = ops.meshgrid(ops.arange(D, dtype=compute_dtype), ops.arange(H, dtype=compute_dtype), ops.arange(W, dtype=compute_dtype), indexing='ij') + grid_d, grid_h, grid_w = ops.meshgrid( + ops.arange(D, dtype=compute_dtype), + ops.arange(H, dtype=compute_dtype), + ops.arange(W, dtype=compute_dtype), + indexing="ij", + ) grid = ops.stack([grid_d, grid_h, grid_w], axis=-1) warp_grid = ops.expand_dims(grid, 0) + flow - + batched_coords = ops.transpose(warp_grid, (0, 4, 1, 2, 3)) def perform_map(elems): @@ -96,25 +174,45 @@ def perform_map(elems): image_slice_transposed = ops.transpose(image_slice, (3, 0, 1, 2)) # The channel dimension C is a static value when the graph is built for c in range(C): - deformed_channels.append(ops.image.map_coordinates(image_slice_transposed[c], coords, order=1)) + deformed_channels.append( + ops.image.map_coordinates( + image_slice_transposed[c], coords, order=1 + ) + ) deformed_image_slice = ops.stack(deformed_channels, axis=0) - deformed_image_slice = ops.transpose(deformed_image_slice, (1, 2, 3, 0)) + deformed_image_slice = ops.transpose( + deformed_image_slice, (1, 2, 3, 0) + ) label_channel = ops.squeeze(label_slice, axis=-1) - deformed_label_channel = ops.image.map_coordinates(label_channel, coords, order=0) - deformed_label_slice = ops.expand_dims(deformed_label_channel, axis=-1) + deformed_label_channel = ops.image.map_coordinates( + label_channel, coords, order=0 + ) + deformed_label_slice = ops.expand_dims( + deformed_label_channel, axis=-1 + ) return deformed_image_slice, deformed_label_slice if backend.backend() == "tensorflow": import tensorflow as tf - deformed_image, deformed_label = tf.map_fn(perform_map, elems=(image_volume, label_volume, batched_coords), dtype=(compute_dtype, compute_dtype)) + + deformed_image, deformed_label = tf.map_fn( + perform_map, + elems=(image_volume, label_volume, batched_coords), + dtype=(compute_dtype, compute_dtype), + ) elif backend.backend() == "jax": import jax - deformed_image, deformed_label = jax.lax.map(perform_map, xs=(image_volume, label_volume, batched_coords)) + + deformed_image, deformed_label = jax.lax.map( + perform_map, xs=(image_volume, label_volume, batched_coords) + ) else: deformed_images_list = [] deformed_labels_list = [] for i in range(B): - img_slice, lbl_slice = perform_map((image_volume[i], label_volume[i], batched_coords[i])) + img_slice, lbl_slice = perform_map( + (image_volume[i], label_volume[i], batched_coords[i]) + ) deformed_images_list.append(img_slice) deformed_labels_list.append(lbl_slice) deformed_image = ops.stack(deformed_images_list, axis=0) @@ -135,5 +233,13 @@ def compute_output_shape(self, input_shape): def get_config(self): config = super().get_config() - config.update({"grid_size": self.grid_size, "alpha": self.alpha, "sigma": self.sigma, "data_format": self.data_format}) - return config \ No newline at end of file + config.update( + { + "grid_size": self.grid_size, + "alpha": self.alpha, + "sigma": self.sigma, + "data_format": self.data_format, + "seed": self.seed, + } + ) + return config diff --git a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py index dddafdbc13..d4a6ef19c5 100644 --- a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py +++ b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py @@ -1,4 +1,3 @@ -# Add keras.utils for the random seed import os import keras @@ -16,12 +15,9 @@ class RandomElasticDeformation3DTest(TestCase): def test_layer_basics(self): - # --- BEST PRACTICE: Add a seed for reproducibility --- utils.set_random_seed(0) layer = RandomElasticDeformation3D( - grid_size=(4, 4, 4), - alpha=10.0, - sigma=2.0, + grid_size=(4, 4, 4), alpha=10.0, sigma=2.0, seed=0 ) image = ops.ones((2, 32, 32, 32, 3), dtype="float32") label = ops.ones((2, 32, 32, 32, 1), dtype="int32") @@ -32,36 +28,38 @@ def test_layer_basics(self): self.assertEqual(label.dtype, output_label.dtype) def test_serialization(self): - # --- BEST PRACTICE: Add a seed for reproducibility --- - utils.set_random_seed(0) layer = RandomElasticDeformation3D( grid_size=(3, 3, 3), alpha=50.0, sigma=5.0, + seed=0, ) image_data = ops.ones((2, 16, 16, 16, 3), dtype="float32") label_data = ops.ones((2, 16, 16, 16, 1), dtype="int32") input_data = (image_data, label_data) + image_input = Input(shape=(16, 16, 16, 3), dtype="float32") label_input = Input(shape=(16, 16, 16, 1), dtype="int32") outputs = layer((image_input, label_input)) model = Model(inputs=[image_input, label_input], outputs=outputs) + original_output_image, original_output_label = model(input_data) - path = os.path.join(self.get_temp_dir(), "model.keras") - # --- FIX: Remove the deprecated save_format argument --- + path = os.path.join(self.get_temp_dir(), "model.keras") model.save(path) - loaded_model = keras.models.load_model( path, custom_objects={ "RandomElasticDeformation3D": RandomElasticDeformation3D }, ) + loaded_output_image, loaded_output_label = loaded_model(input_data) + np.testing.assert_allclose( ops.convert_to_numpy(original_output_image), ops.convert_to_numpy(loaded_output_image), + atol=1e-6, ) np.testing.assert_array_equal( ops.convert_to_numpy(original_output_label), @@ -69,7 +67,6 @@ def test_serialization(self): ) def test_label_values_are_preserved(self): - # --- BEST PRACTICE: Add a seed for reproducibility --- utils.set_random_seed(0) image = ops.zeros(shape=(1, 16, 16, 16, 1), dtype="float32") label_arange = ops.arange(16**3, dtype="int32")