Skip to content

Conversation

ashpakshaikh26732
Copy link

Hello Keras Hub Team,

This PR adds the RandomElasticDeformation3D layer as proposed and approved in issue #2409.

Description
This new layer provides a state-of-the-art, TPU-optimized 3D elastic deformation augmentation. It is a pure, native TensorFlow implementation, making it fully graph-compatible and suitable for high-performance pipelines.

Key features include:

  • A high-performance implementation using bfloat16 and a fast, separable 3D Gaussian filter.
  • Support for both batched (5D) and unbatched (4D) inputs.
  • Flexible data_format handling ("DHWC" and "HWDC").
  • Correctly uses trilinear interpolation for images and nearest-neighbor for segmentation masks.

I have included a comprehensive set of unit tests, which all pass locally, to ensure correctness and robustness. Looking forward to your feedback!

Fixes #2409

Copy link

google-cla bot commented Sep 25, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Copy link
Contributor

Summary of Changes

Hello @ashpakshaikh26732, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a new RandomElasticDeformation3D layer into Keras Hub, designed to perform advanced 3D elastic deformation augmentation. This layer is built with a focus on performance, leveraging native TensorFlow operations and bfloat16 for TPU and GPU efficiency. It provides essential data augmentation capabilities for 3D data, such as medical images, by introducing realistic, non-linear deformations while maintaining data integrity for both images and their corresponding segmentation masks.

Highlights

  • New 3D Elastic Deformation Layer: Introduces RandomElasticDeformation3D, a new Keras layer for 3D elastic deformation augmentation, addressing issue Contribution: A Native 3D Elastic Deformation Augmentation Layer #2409.
  • TPU-Optimized Implementation: The layer is a pure, native TensorFlow implementation, optimized for high performance on TPUs and GPUs, utilizing bfloat16 and a fast separable 3D Gaussian filter.
  • Flexible Input Handling: Supports both batched (5D) and unbatched (4D) inputs, along with flexible data_format handling for "DHWC" and "HWDC".
  • Interpolation Methods: Correctly applies trilinear interpolation for image volumes and nearest-neighbor interpolation for segmentation masks during deformation.
  • Comprehensive Unit Tests: Includes a thorough set of unit tests to ensure the correctness and robustness of the new layer, covering various input scenarios and configuration serialization.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This PR introduces a new RandomElasticDeformation3D layer, which is a great addition. The implementation is clever, especially the use of separable filters and resizing for performance. However, there are a few major issues that need to be addressed before this can be merged:

  1. Backend Agnosticism: The implementation is TensorFlow-specific, but the Keras Hub style guide requires all contributions to be backend-agnostic (supporting TensorFlow, JAX, and PyTorch). This means all tensorflow API calls must be replaced with their keras.ops or keras.layers equivalents.
  2. Testing Framework: The tests do not use the project's standard testing framework (keras_hub.src.tests.test_case.TestCase and its helper methods like run_layer_test). Adhering to the testing guidelines is crucial for maintainability.
  3. Serialization: The layer is missing the get_config method and stores __init__ arguments in a way that prevents serialization, which is required for all layers.
  4. Validation Notebook: The style guide requires a Colab notebook demonstrating numerical equivalence with an original implementation if one exists. Please add a link to it in the PR description.

I've left specific comments with suggestions on how to address these points. Great work on the core algorithm, and I look forward to seeing the updated version!

@@ -0,0 +1,127 @@
import tensorflow as tf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The implementation uses the tensorflow package directly, which violates the backend-agnostic principle of Keras Hub.1 All code must support TensorFlow, JAX, and PyTorch. Please refactor the layer to use keras.ops and keras.layers instead of tf.* functions.

For example:

  • import tensorflow as tf should be replaced with from keras import ops and from keras import layers.
  • tf.keras.layers.Layer should be layers.Layer.
  • tf.constant should be ops.convert_to_tensor.
  • tf.nn.convolution should be replaced with ops.conv.
  • tf.image.resize should be ops.image.resize.
  • ... and so on for all other tf calls.

This is a fundamental requirement for all contributions.

Style Guide References

Footnotes

  1. All code must be keras 3 backend-agnostic, supporting TensorFlow, JAX, and PyTorch backends.

from tensorflow import keras
from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import RandomElasticDeformation3D

class RandomElasticDeformation3DTest(tf.test.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This test class should inherit from keras_hub.src.tests.test_case.TestCase instead of tf.test.TestCase. The Keras Hub testing framework provides standardized helper methods that should be used to ensure consistent and thorough testing across the library.1

Please refactor the tests to use self.run_layer_test() for basic checks and self.run_model_saving_test() for serialization, as outlined in the contribution guidelines.2

Style Guide References

Footnotes

  1. KerasHub provides helper methods in the TestCase class that handle the standardized test routines. Users should use these methods instead of writing tests from scratch.

  2. Available Test Helper Methods: self.run_layer_test(), self.run_model_saving_test()

Comment on lines 16 to 17
self.alpha = tf.constant(alpha, dtype=tf.bfloat16)
self.sigma = tf.constant(sigma, dtype=tf.bfloat16)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This __init__ method has a couple of issues regarding style guide compliance and flexibility:

  1. It violates the style guide by not storing the original alpha and sigma arguments as attributes.1 They are immediately converted to tensors, which prevents the layer from being serializable because a get_config() method cannot be correctly implemented.2
  2. The dtype is hardcoded to bfloat16. It's better to use the layer's compute_dtype to respect the model's overall dtype policy.

Please refactor __init__ to address these points. You will also need to add a get_config() method to the class. You would then need to update the call method to use internal tensor attributes (e.g., self._alpha_tensor).

Suggested change
self.alpha = tf.constant(alpha, dtype=tf.bfloat16)
self.sigma = tf.constant(sigma, dtype=tf.bfloat16)
self.alpha = alpha
self.sigma = sigma

Style Guide References

Footnotes

  1. Keep Python attributes on the layer for each __init__ argument to the layer. The name and value should match the passed value.

  2. Write a get_config() which chains to super.

Comment on lines 55 to 67
def test_config_serialization(self):
layer = RandomElasticDeformation3D(
grid_size=(3, 3, 3),
alpha=50.0,
sigma=5.0,
data_format="HWDC"
)
config = layer.get_config()
new_layer = RandomElasticDeformation3D.from_config(config)
self.assertEqual(new_layer.grid_size, (3, 3, 3))
self.assertAllClose(new_layer.alpha, tf.constant(50.0, dtype=tf.bfloat16))
self.assertAllClose(new_layer.sigma, tf.constant(5.0, dtype=tf.bfloat16))
self.assertEqual(new_layer.data_format, "HWDC")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This serialization test will fail because the get_config method is not implemented in the RandomElasticDeformation3D layer. Per the style guide, you should use self.run_model_saving_test() to test serialization, which is more comprehensive.1

Style Guide References

Footnotes

  1. Use self.run_model_saving_test() for testing model serialization.

warp_grid_floor = tf.floor(warp_grid)
t = warp_grid - warp_grid_floor

d0 = tf.cast(warp_grid_floor[..., 0], tf.int32); h0 = tf.cast(warp_grid_floor[..., 1], tf.int32); w0 = tf.cast(warp_grid_floor[..., 2], tf.int32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Placing multiple statements on a single line using semicolons is discouraged as it harms readability. Please split these assignments onto separate lines.

Suggested change
d0 = tf.cast(warp_grid_floor[..., 0], tf.int32); h0 = tf.cast(warp_grid_floor[..., 1], tf.int32); w0 = tf.cast(warp_grid_floor[..., 2], tf.int32)
d0 = tf.cast(warp_grid_floor[..., 0], tf.int32)
h0 = tf.cast(warp_grid_floor[..., 1], tf.int32)
w0 = tf.cast(warp_grid_floor[..., 2], tf.int32)

input_image = tf.random.uniform(shape=(2, 32, 64, 64, 3), dtype=tf.float32)
input_label = tf.random.uniform(shape=(2, 32, 64, 64, 1), maxval=4, dtype=tf.int32)
layer = RandomElasticDeformation3D(data_format="DHWC")
output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test casts integer labels to tf.float32 before passing them to the layer. For segmentation masks, it's common to use integer types. The layer should ideally support integer-dtype labels directly, and the tests should reflect that. The nearest-neighbor interpolation for labels should work correctly with integer types.

@ashpakshaikh26732 ashpakshaikh26732 force-pushed the add_3d_elastic_deformation branch 2 times, most recently from dc925ce to 88836c8 Compare September 25, 2025 17:51
@ashpakshaikh26732 ashpakshaikh26732 force-pushed the add_3d_elastic_deformation branch from 88836c8 to 4cf6ee0 Compare September 26, 2025 04:21
@ashpakshaikh26732
Copy link
Author

Hi Keras team,

It looks like my changes for the RandomElasticDeformation3D layer are complete. All the tests related to formatting, serialization, and backend compatibility for my layer seem to be passing now.

The CI is still failing on one last test, but the error log refers to Arange and T5/Moonshine models, which are unrelated to my changes. This appears to be a pre-existing issue in the main branch.

Could you please take a look and perhaps re-run the failing job?

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Contribution: A Native 3D Elastic Deformation Augmentation Layer
1 participant