4 changes: 4 additions & 0 deletions keras_hub/api/layers/__init__.py
@@ -147,3 +147,7 @@
from keras_hub.src.models.xception.xception_image_converter import (
    XceptionImageConverter as XceptionImageConverter,
)

from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import (
    RandomElasticDeformation3D,
)
127 changes: 127 additions & 0 deletions keras_hub/src/layers/preprocessing/random_elastic_deformation_3d.py
@@ -0,0 +1,127 @@
import tensorflow as tf
Contributor

critical

The implementation uses the tensorflow package directly, which violates the backend-agnostic principle of Keras Hub.1 All code must support TensorFlow, JAX, and PyTorch. Please refactor the layer to use keras.ops and keras.layers instead of tf.* functions.

For example:

  • import tensorflow as tf should be replaced with from keras import ops and from keras import layers.
  • tf.keras.layers.Layer should be layers.Layer.
  • tf.constant should be ops.convert_to_tensor.
  • tf.nn.convolution should be replaced with ops.conv.
  • tf.image.resize should be ops.image.resize.
  • ... and so on for all other tf calls.

This is a fundamental requirement for all contributions.
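
As a rough, non-authoritative sketch of where the refactor could start (names and defaults mirror this PR; the keras.ops calls shown are only illustrative, not a complete port):

from keras import layers
from keras import ops


class RandomElasticDeformation3D(layers.Layer):
    def __init__(self, grid_size=(4, 4, 4), alpha=35.0, sigma=2.5,
                 data_format="DHWC", **kwargs):
        super().__init__(**kwargs)
        self.grid_size = grid_size
        self.alpha = alpha
        self.sigma = sigma
        self.data_format = data_format

    def call(self, inputs):
        image, label = inputs
        # Backend-agnostic equivalents of the current tf.* calls would go
        # here, e.g. ops.cast, ops.conv for the separable blur, and
        # ops.image.resize for upsampling the coarse displacement field.
        image = ops.cast(image, self.compute_dtype)
        ...
        return image, label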

Style Guide References

Footnotes

  1. All code must be keras 3 backend-agnostic, supporting TensorFlow, JAX, and PyTorch backends.


class RandomElasticDeformation3D(tf.keras.layers.Layer):
    """
    A high-performance 3D elastic deformation layer optimized for TPUs and GPUs.
    ... (docstring is the same) ...
    """
    def __init__(self,
                 grid_size=(4, 4, 4),
                 alpha=35.0,
                 sigma=2.5,
                 data_format="DHWC",
                 **kwargs):
        super().__init__(**kwargs)
        self.grid_size = grid_size
        self.alpha = tf.constant(alpha, dtype=tf.bfloat16)
        self.sigma = tf.constant(sigma, dtype=tf.bfloat16)
Contributor

high

This __init__ method has a couple of issues regarding style guide compliance and flexibility:

  1. It violates the style guide by not storing the original alpha and sigma arguments as attributes.1 They are immediately converted to tensors, which prevents the layer from being serializable because a get_config() method cannot be correctly implemented.2
  2. The dtype is hardcoded to bfloat16. It's better to use the layer's compute_dtype to respect the model's overall dtype policy.

Please refactor __init__ to address these points. You will also need to add a get_config() method to the class. You would then need to update the call method to use internal tensor attributes (e.g., self._alpha_tensor).

Suggested change
self.alpha = tf.constant(alpha, dtype=tf.bfloat16)
self.sigma = tf.constant(sigma, dtype=tf.bfloat16)
self.alpha = alpha
self.sigma = sigma

Style Guide References

Footnotes

  1. Keep Python attributes on the layer for each __init__ argument to the layer. The name and value should match the passed value.

  2. Write a get_config() which chains to super.
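
A minimal get_config() sketch along these lines (assuming alpha and sigma are kept as plain Python attributes, as suggested above):

    def get_config(self):
        config = super().get_config()
        config.update({
            "grid_size": self.grid_size,
            "alpha": self.alpha,
            "sigma": self.sigma,
            "data_format": self.data_format,
        })
        return config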

        if data_format not in ["DHWC", "HWDC"]:
            raise ValueError("`data_format` must be one of 'DHWC' or 'HWDC'")
        self.data_format = data_format

    def _separable_gaussian_filter_3d(self, tensor, sigma):
        # Build a 1D Gaussian kernel, then convolve it separately along the
        # depth, height, and width axes (separable 3D smoothing).
        kernel_size = tf.cast(2 * tf.round(3 * sigma) + 1, dtype=tf.int32)
        ax = tf.range(-tf.cast(kernel_size // 2, tf.bfloat16) + 1.0,
                      tf.cast(kernel_size // 2, tf.bfloat16) + 1.0)
        kernel_1d = tf.exp(-(ax**2) / (2.0 * self.sigma**2))
        kernel_1d = kernel_1d / tf.reduce_sum(kernel_1d)
        filter_d = tf.cast(tf.reshape(kernel_1d, [-1, 1, 1, 1, 1]), dtype=tensor.dtype)
        filter_h = tf.cast(tf.reshape(kernel_1d, [1, -1, 1, 1, 1]), dtype=tensor.dtype)
        filter_w = tf.cast(tf.reshape(kernel_1d, [1, 1, -1, 1, 1]), dtype=tensor.dtype)
        tensor = tf.nn.convolution(tensor, filter_d, strides=1, padding='SAME')
        tensor = tf.nn.convolution(tensor, filter_h, strides=1, padding='SAME')
        tensor = tf.nn.convolution(tensor, filter_w, strides=1, padding='SAME')
        return tensor

    def call(self, inputs):
        image_volume, label_volume = inputs
        original_image_dtype = image_volume.dtype

        was_batched = True
        if image_volume.shape.rank == 4:
            was_batched = False
            image_volume = tf.expand_dims(image_volume, axis=0)
            label_volume = tf.expand_dims(label_volume, axis=0)

        if self.data_format == "HWDC":
            image_volume = tf.transpose(image_volume, perm=[0, 3, 1, 2, 4])
            label_volume = tf.transpose(label_volume, perm=[0, 3, 1, 2, 4])

        image_volume = tf.cast(image_volume, dtype=tf.bfloat16)
        input_shape = tf.shape(image_volume)
        B, D, H, W = input_shape[0], input_shape[1], input_shape[2], input_shape[3]

        # Sample a coarse random displacement field on the control grid.
        coarse_flow = tf.random.uniform(
            shape=(B, self.grid_size[0], self.grid_size[1], self.grid_size[2], 3),
            minval=-1, maxval=1, dtype=tf.bfloat16)

        # Upsample the coarse field to full volume resolution using two 2D
        # bicubic resizes (H/W first, then depth).
        flow = tf.reshape(coarse_flow, [B * self.grid_size[0], self.grid_size[1], self.grid_size[2], 3])
        flow = tf.image.resize(flow, size=[H, W], method='bicubic')
        flow = tf.reshape(flow, [B, self.grid_size[0], H, W, 3])
        flow = tf.transpose(flow, perm=[0, 2, 3, 1, 4])
        flow = tf.reshape(flow, [B * H * W, self.grid_size[0], 3])
        flow = tf.image.resize(tf.expand_dims(flow, axis=1), size=[1, D], method='bicubic')
        flow = tf.squeeze(flow, axis=1)
        flow = tf.reshape(flow, [B, H, W, D, 3])
        flow = tf.transpose(flow, perm=[0, 3, 1, 2, 4])

        flow = tf.cast(flow, dtype=tf.bfloat16)

        # Smooth each displacement component with the separable Gaussian filter.
        flow_components = tf.unstack(flow, axis=-1)
        smoothed_components = []
        for component in flow_components:
            smoothed_component = self._separable_gaussian_filter_3d(
                component[..., tf.newaxis], self.sigma
            )
            smoothed_components.append(smoothed_component[..., 0])
        smoothed_flow = tf.stack(smoothed_components, axis=-1)

        flow = smoothed_flow * self.alpha

        # Build the base sampling grid and add the random displacement.
        grid_d, grid_h, grid_w = tf.meshgrid(
            tf.range(D, dtype=tf.bfloat16),
            tf.range(H, dtype=tf.bfloat16),
            tf.range(W, dtype=tf.bfloat16),
            indexing='ij'
        )
        grid = tf.stack([grid_d, grid_h, grid_w], axis=-1)

        warp_grid = tf.expand_dims(grid, 0) + flow

        # Trilinear interpolation of the image at the warped coordinates.
        warp_grid_floor = tf.floor(warp_grid)
        t = warp_grid - warp_grid_floor

        d0 = tf.cast(warp_grid_floor[..., 0], tf.int32); h0 = tf.cast(warp_grid_floor[..., 1], tf.int32); w0 = tf.cast(warp_grid_floor[..., 2], tf.int32)
Contributor

medium

Placing multiple statements on a single line using semicolons is discouraged as it harms readability. Please split these assignments onto separate lines.

Suggested change
d0 = tf.cast(warp_grid_floor[..., 0], tf.int32); h0 = tf.cast(warp_grid_floor[..., 1], tf.int32); w0 = tf.cast(warp_grid_floor[..., 2], tf.int32)
d0 = tf.cast(warp_grid_floor[..., 0], tf.int32)
h0 = tf.cast(warp_grid_floor[..., 1], tf.int32)
w0 = tf.cast(warp_grid_floor[..., 2], tf.int32)

        d1 = tf.clip_by_value(d0 + 1, 0, D - 1); h1 = tf.clip_by_value(h0 + 1, 0, H - 1); w1 = tf.clip_by_value(w0 + 1, 0, W - 1)
        d0 = tf.clip_by_value(d0, 0, D - 1); h0 = tf.clip_by_value(h0, 0, H - 1); w0 = tf.clip_by_value(w0, 0, W - 1)

        # Gather the eight neighboring corner voxels for each warped coordinate.
        c000 = tf.gather_nd(image_volume, tf.stack([d0, h0, w0], axis=-1), batch_dims=1); c001 = tf.gather_nd(image_volume, tf.stack([d0, h0, w1], axis=-1), batch_dims=1)
        c010 = tf.gather_nd(image_volume, tf.stack([d0, h1, w0], axis=-1), batch_dims=1); c011 = tf.gather_nd(image_volume, tf.stack([d0, h1, w1], axis=-1), batch_dims=1)
        c100 = tf.gather_nd(image_volume, tf.stack([d1, h0, w0], axis=-1), batch_dims=1); c101 = tf.gather_nd(image_volume, tf.stack([d1, h0, w1], axis=-1), batch_dims=1)
        c110 = tf.gather_nd(image_volume, tf.stack([d1, h1, w0], axis=-1), batch_dims=1); c111 = tf.gather_nd(image_volume, tf.stack([d1, h1, w1], axis=-1), batch_dims=1)

        # Blend along W, then H, then D.
        td, th, tw = t[..., 0:1], t[..., 1:2], t[..., 2:3]
        c00 = c000*(1-tw) + c001*tw; c01 = c010*(1-tw) + c011*tw; c10 = c100*(1-tw) + c101*tw; c11 = c110*(1-tw) + c111*tw
        c0 = c00*(1-th) + c01*th; c1 = c10*(1-th) + c11*th
        deformed_image = c0*(1-td) + c1*td
        deformed_image = tf.cast(deformed_image, original_image_dtype)

        # Labels are resampled with nearest-neighbor indices to keep discrete class values.
        nearest_indices_float = tf.round(warp_grid)
        nearest_d = tf.clip_by_value(tf.cast(nearest_indices_float[..., 0], tf.int32), 0, D - 1)
        nearest_h = tf.clip_by_value(tf.cast(nearest_indices_float[..., 1], tf.int32), 0, H - 1)
        nearest_w = tf.clip_by_value(tf.cast(nearest_indices_float[..., 2], tf.int32), 0, W - 1)
        deformed_label = tf.gather_nd(label_volume, tf.stack([nearest_d, nearest_h, nearest_w], axis=-1), batch_dims=1)

        if self.data_format == "HWDC":
            deformed_image = tf.transpose(deformed_image, perm=[0, 2, 3, 1, 4])
            deformed_label = tf.transpose(deformed_label, perm=[0, 2, 3, 1, 4])

        if not was_batched:
            deformed_image = tf.squeeze(deformed_image, axis=0)
            deformed_label = tf.squeeze(deformed_label, axis=0)

        return deformed_image, deformed_label
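
For reference, a minimal usage sketch of the layer as it stands in this PR (shapes follow the unit tests below; purely illustrative):

import tensorflow as tf

image = tf.random.uniform((2, 32, 64, 64, 3), dtype=tf.float32)
label = tf.cast(
    tf.random.uniform((2, 32, 64, 64, 1), maxval=4, dtype=tf.int32), tf.float32
)

layer = RandomElasticDeformation3D(grid_size=(4, 4, 4), alpha=35.0, sigma=2.5)
deformed_image, deformed_label = layer((image, label))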
@@ -0,0 +1,70 @@
import tensorflow as tf
from tensorflow import keras
from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import RandomElasticDeformation3D

class RandomElasticDeformation3DTest(tf.test.TestCase):
Contributor

critical

This test class should inherit from keras_hub.src.tests.test_case.TestCase instead of tf.test.TestCase. The Keras Hub testing framework provides standardized helper methods that should be used to ensure consistent and thorough testing across the library.1

Please refactor the tests to use self.run_layer_test() for basic checks and self.run_model_saving_test() for serialization, as outlined in the contribution guidelines.2

Style Guide References

Footnotes

  1. KerasHub provides helper methods in the TestCase class that handle the standardized test routines. Users should use these methods instead of writing tests from scratch.

  2. Available Test Helper Methods: self.run_layer_test(), self.run_model_saving_test()
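
A minimal sketch of the suggested test structure (helper names come from the comment above; their exact arguments should be checked against keras_hub.src.tests.test_case.TestCase):

from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import (
    RandomElasticDeformation3D,
)


class RandomElasticDeformation3DTest(TestCase):
    def test_layer_basics(self):
        # self.run_layer_test(...) and self.run_model_saving_test(...) would
        # replace the hand-rolled assertions below; argument names are not
        # shown here because they should be taken from the TestCase helpers.
        ...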


    def test_output_shape_is_same_as_input_dhwc(self):
        input_image = tf.random.uniform(shape=(2, 32, 64, 64, 3), dtype=tf.float32)
        input_label = tf.random.uniform(shape=(2, 32, 64, 64, 1), maxval=4, dtype=tf.int32)
        layer = RandomElasticDeformation3D(data_format="DHWC")
        output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32)))
Contributor

medium

The test casts integer labels to tf.float32 before passing them to the layer. For segmentation masks, it's common to use integer types. The layer should ideally support integer-dtype labels directly, and the tests should reflect that. The nearest-neighbor interpolation for labels should work correctly with integer types.
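
For example, the call could pass the labels through unchanged (assuming the layer accepts integer label tensors, which the nearest-neighbor gather should support):

        # Hypothetical variant: no cast to float before calling the layer.
        output_image, output_label = layer((input_image, input_label))
        self.assertEqual(output_label.dtype, tf.int32)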

        self.assertAllEqual(tf.shape(input_image), tf.shape(output_image))
        self.assertAllEqual(tf.shape(input_label), tf.shape(output_label))

    def test_output_shape_is_same_as_input_hwdc(self):
        input_image = tf.random.uniform(shape=(2, 64, 64, 32, 3), dtype=tf.float32)
        input_label = tf.random.uniform(shape=(2, 64, 64, 32, 1), maxval=4, dtype=tf.int32)
        layer = RandomElasticDeformation3D(data_format="HWDC")
        output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32)))
        self.assertAllEqual(tf.shape(input_image), tf.shape(output_image))
        self.assertAllEqual(tf.shape(input_label), tf.shape(output_label))

    def test_unbatched_input(self):
        input_image = tf.random.uniform(shape=(32, 64, 64, 3), dtype=tf.float32)
        input_label = tf.random.uniform(shape=(32, 64, 64, 1), maxval=4, dtype=tf.int32)
        layer = RandomElasticDeformation3D(data_format="DHWC")
        output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32)))
        self.assertAllEqual(tf.shape(input_image), tf.shape(output_image))
        self.assertEqual(tf.rank(output_image), 4)

    def test_dtype_preservation(self):
        input_image = tf.random.uniform(shape=(2, 16, 16, 16, 3), dtype=tf.float32)
        input_label = tf.random.uniform(shape=(2, 16, 16, 16, 1), maxval=4, dtype=tf.int32)
        layer = RandomElasticDeformation3D()
        output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32)))
        self.assertEqual(output_image.dtype, tf.float32)
        self.assertEqual(output_label.dtype, tf.float32)

    def test_label_values_are_preserved(self):
        input_image = tf.zeros(shape=(1, 16, 16, 16, 1), dtype=tf.float32)
        label_arange = tf.experimental.numpy.arange(16**3)
        input_label = tf.reshape(label_arange, (1, 16, 16, 16, 1))
        input_label = tf.cast(input_label, dtype=tf.float32) % 4

        layer = RandomElasticDeformation3D(alpha=80.0, sigma=8.0)
        _, output_label = layer((input_image, input_label))

        unique_values_tensor = tf.unique(tf.reshape(output_label, [-1]))[0]

        expected_values = [0., 1., 2., 3.]
        actual_values = unique_values_tensor.numpy().tolist()
        self.assertContainsSubset(expected_values, actual_values)

    def test_config_serialization(self):
        layer = RandomElasticDeformation3D(
            grid_size=(3, 3, 3),
            alpha=50.0,
            sigma=5.0,
            data_format="HWDC"
        )
        config = layer.get_config()
        new_layer = RandomElasticDeformation3D.from_config(config)
        self.assertEqual(new_layer.grid_size, (3, 3, 3))
        self.assertAllClose(new_layer.alpha, tf.constant(50.0, dtype=tf.bfloat16))
        self.assertAllClose(new_layer.sigma, tf.constant(5.0, dtype=tf.bfloat16))
        self.assertEqual(new_layer.data_format, "HWDC")
Contributor

high

This serialization test will fail because the get_config method is not implemented in the RandomElasticDeformation3D layer. Per the style guide, you should use self.run_model_saving_test() to test serialization, which is more comprehensive.1

Style Guide References

Footnotes

  1. Use self.run_model_saving_test() for testing model serialization.


if __name__ == "__main__":
    tf.test.main()