diff --git a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py new file mode 100644 index 0000000000..98b1a892ed --- /dev/null +++ b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py @@ -0,0 +1,245 @@ +# Add this import +from keras import backend +from keras import layers +from keras import ops +from keras import random + + +class RandomElasticDeformation3D(layers.Layer): + """ + A high-performance 3D elastic deformation layer optimized for TPUs. + """ + + def __init__( + self, + grid_size=(4, 4, 4), + alpha=35.0, + sigma=2.5, + data_format="channels_last", + seed=None, + **kwargs, + ): + super().__init__(**kwargs) + self.grid_size = grid_size + self.seed = seed + self.alpha = alpha + self.sigma = sigma + self.data_format = data_format + self._rng = random.SeedGenerator(seed) if seed is not None else None + if data_format not in ["channels_last", "channels_first"]: + message = ( + "`data_format` must be one of 'channels_last' or " + f"'channels_first'. Received: {self.data_format}" + ) + raise ValueError(message) + + def build(self, input_shape): + self._alpha_tensor = ops.convert_to_tensor( + self.alpha, dtype=self.compute_dtype + ) + self._sigma_tensor = ops.convert_to_tensor( + self.sigma, dtype=self.compute_dtype + ) + kernel_size = ops.cast( + 2 * ops.round(3 * self._sigma_tensor) + 1, dtype="int32" + ) + ax = ops.arange( + -ops.cast(kernel_size // 2, self.compute_dtype) + 1.0, + ops.cast(kernel_size // 2, self.compute_dtype) + 1.0, + ) + kernel_1d = ops.exp(-(ax**2) / (2.0 * self._sigma_tensor**2)) + self.kernel_1d = kernel_1d / ops.sum(kernel_1d) + self.built = True + + def _separable_gaussian_filter_3d(self, tensor): + depth_kernel = ops.reshape(self.kernel_1d, (-1, 1, 1, 1, 1)) + tensor = ops.conv( + tensor, ops.cast(depth_kernel, dtype=tensor.dtype), padding="same" + ) + height_kernel = ops.reshape(self.kernel_1d, (1, -1, 1, 1, 1)) + tensor = ops.conv( + tensor, ops.cast(height_kernel, dtype=tensor.dtype), padding="same" + ) + width_kernel = ops.reshape(self.kernel_1d, (1, 1, -1, 1, 1)) + tensor = ops.conv( + tensor, ops.cast(width_kernel, dtype=tensor.dtype), padding="same" + ) + return tensor + + def call(self, inputs): + image_volume, label_volume = inputs + original_image_dtype = image_volume.dtype + original_label_dtype = label_volume.dtype + compute_dtype = self.compute_dtype + + was_batched = True + if len(image_volume.shape) == 4: + was_batched = False + image_volume = ops.expand_dims(image_volume, axis=0) + label_volume = ops.expand_dims(label_volume, axis=0) + + image_volume = ops.cast(image_volume, dtype=compute_dtype) + label_volume = ops.cast(label_volume, dtype=compute_dtype) + + input_shape = ops.shape(image_volume) + B, D, H, W, C = ( + input_shape[0], + input_shape[1], + input_shape[2], + input_shape[3], + input_shape[4], + ) + + if self._rng is not None: + coarse_flow = random.uniform( + shape=( + B, + self.grid_size[0], + self.grid_size[1], + self.grid_size[2], + 3, + ), + minval=-1, + maxval=1, + dtype=compute_dtype, + seed=self._rng, + ) + else: + coarse_flow = random.uniform( + shape=( + B, + self.grid_size[0], + self.grid_size[1], + self.grid_size[2], + 3, + ), + minval=-1, + maxval=1, + dtype=compute_dtype, + ) + + flow = coarse_flow + flow_shape = ops.shape(flow) + flow = ops.reshape( + flow, + (flow_shape[0] * flow_shape[1], flow_shape[2], flow_shape[3], 3), + ) + flow = ops.image.resize(flow, (H, W), interpolation="bicubic") + flow = ops.reshape(flow, (flow_shape[0], flow_shape[1], H, W, 3)) + flow = ops.transpose(flow, (0, 2, 3, 1, 4)) + flow_shape = ops.shape(flow) + flow = ops.reshape( + flow, + ( + flow_shape[0] * flow_shape[1] * flow_shape[2], + flow_shape[3], + 1, + 3, + ), + ) + flow = ops.image.resize(flow, (D, 1), interpolation="bicubic") + flow = ops.reshape( + flow, (flow_shape[0], flow_shape[1], flow_shape[2], D, 3) + ) + flow = ops.transpose(flow, (0, 3, 1, 2, 4)) + + flow_components = ops.unstack(flow, axis=-1) + smoothed_components = [] + for component in flow_components: + smoothed_components.append( + ops.squeeze( + self._separable_gaussian_filter_3d( + ops.expand_dims(component, axis=-1) + ), + axis=-1, + ) + ) + smoothed_flow = ops.stack(smoothed_components, axis=-1) + + flow = smoothed_flow * self._alpha_tensor + grid_d, grid_h, grid_w = ops.meshgrid( + ops.arange(D, dtype=compute_dtype), + ops.arange(H, dtype=compute_dtype), + ops.arange(W, dtype=compute_dtype), + indexing="ij", + ) + grid = ops.stack([grid_d, grid_h, grid_w], axis=-1) + warp_grid = ops.expand_dims(grid, 0) + flow + + batched_coords = ops.transpose(warp_grid, (0, 4, 1, 2, 3)) + + def perform_map(elems): + image_slice, label_slice, coords = elems + deformed_channels = [] + image_slice_transposed = ops.transpose(image_slice, (3, 0, 1, 2)) + # The channel dimension C is a static value when the graph is built + for c in range(C): + deformed_channels.append( + ops.image.map_coordinates( + image_slice_transposed[c], coords, order=1 + ) + ) + deformed_image_slice = ops.stack(deformed_channels, axis=0) + deformed_image_slice = ops.transpose( + deformed_image_slice, (1, 2, 3, 0) + ) + label_channel = ops.squeeze(label_slice, axis=-1) + deformed_label_channel = ops.image.map_coordinates( + label_channel, coords, order=0 + ) + deformed_label_slice = ops.expand_dims( + deformed_label_channel, axis=-1 + ) + return deformed_image_slice, deformed_label_slice + + if backend.backend() == "tensorflow": + import tensorflow as tf + + deformed_image, deformed_label = tf.map_fn( + perform_map, + elems=(image_volume, label_volume, batched_coords), + dtype=(compute_dtype, compute_dtype), + ) + elif backend.backend() == "jax": + import jax + + deformed_image, deformed_label = jax.lax.map( + perform_map, xs=(image_volume, label_volume, batched_coords) + ) + else: + deformed_images_list = [] + deformed_labels_list = [] + for i in range(B): + img_slice, lbl_slice = perform_map( + (image_volume[i], label_volume[i], batched_coords[i]) + ) + deformed_images_list.append(img_slice) + deformed_labels_list.append(lbl_slice) + deformed_image = ops.stack(deformed_images_list, axis=0) + deformed_label = ops.stack(deformed_labels_list, axis=0) + + deformed_image = ops.cast(deformed_image, original_image_dtype) + deformed_label = ops.cast(deformed_label, original_label_dtype) + + if not was_batched: + deformed_image = ops.squeeze(deformed_image, axis=0) + deformed_label = ops.squeeze(deformed_label, axis=0) + + return deformed_image, deformed_label + + def compute_output_shape(self, input_shape): + image_shape, label_shape = input_shape + return image_shape, label_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "grid_size": self.grid_size, + "alpha": self.alpha, + "sigma": self.sigma, + "data_format": self.data_format, + "seed": self.seed, + } + ) + return config diff --git a/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py new file mode 100644 index 0000000000..d4a6ef19c5 --- /dev/null +++ b/keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py @@ -0,0 +1,78 @@ +import os + +import keras +import numpy as np +from keras import Input +from keras import Model +from keras import ops +from keras import utils + +from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import ( + RandomElasticDeformation3D, +) +from keras_hub.src.tests.test_case import TestCase + + +class RandomElasticDeformation3DTest(TestCase): + def test_layer_basics(self): + utils.set_random_seed(0) + layer = RandomElasticDeformation3D( + grid_size=(4, 4, 4), alpha=10.0, sigma=2.0, seed=0 + ) + image = ops.ones((2, 32, 32, 32, 3), dtype="float32") + label = ops.ones((2, 32, 32, 32, 1), dtype="int32") + output_image, output_label = layer((image, label)) + self.assertEqual(ops.shape(image), ops.shape(output_image)) + self.assertEqual(ops.shape(label), ops.shape(output_label)) + self.assertEqual(image.dtype, output_image.dtype) + self.assertEqual(label.dtype, output_label.dtype) + + def test_serialization(self): + layer = RandomElasticDeformation3D( + grid_size=(3, 3, 3), + alpha=50.0, + sigma=5.0, + seed=0, + ) + image_data = ops.ones((2, 16, 16, 16, 3), dtype="float32") + label_data = ops.ones((2, 16, 16, 16, 1), dtype="int32") + input_data = (image_data, label_data) + + image_input = Input(shape=(16, 16, 16, 3), dtype="float32") + label_input = Input(shape=(16, 16, 16, 1), dtype="int32") + outputs = layer((image_input, label_input)) + model = Model(inputs=[image_input, label_input], outputs=outputs) + + original_output_image, original_output_label = model(input_data) + + path = os.path.join(self.get_temp_dir(), "model.keras") + model.save(path) + loaded_model = keras.models.load_model( + path, + custom_objects={ + "RandomElasticDeformation3D": RandomElasticDeformation3D + }, + ) + + loaded_output_image, loaded_output_label = loaded_model(input_data) + + np.testing.assert_allclose( + ops.convert_to_numpy(original_output_image), + ops.convert_to_numpy(loaded_output_image), + atol=1e-6, + ) + np.testing.assert_array_equal( + ops.convert_to_numpy(original_output_label), + ops.convert_to_numpy(loaded_output_label), + ) + + def test_label_values_are_preserved(self): + utils.set_random_seed(0) + image = ops.zeros(shape=(1, 16, 16, 16, 1), dtype="float32") + label_arange = ops.arange(16**3, dtype="int32") + label = ops.reshape(label_arange, (1, 16, 16, 16, 1)) % 4 + layer = RandomElasticDeformation3D(alpha=80.0, sigma=8.0) + _, output_label = layer((image, label)) + output_values = set(np.unique(ops.convert_to_numpy(output_label))) + expected_values = {0, 1, 2, 3} + self.assertLessEqual(output_values, expected_values)