-
Notifications
You must be signed in to change notification settings - Fork 301
feat(layers): Add TPU-optimized 3D Elastic Deformation layer #2419
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
feat(layers): Add TPU-optimized 3D Elastic Deformation layer #2419
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
Summary of ChangesHello @ashpakshaikh26732, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates a new Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This PR introduces a new RandomElasticDeformation3D
layer, which is a great addition. The implementation is clever, especially the use of separable filters and resizing for performance. However, there are a few major issues that need to be addressed before this can be merged:
- Backend Agnosticism: The implementation is TensorFlow-specific, but the Keras Hub style guide requires all contributions to be backend-agnostic (supporting TensorFlow, JAX, and PyTorch). This means all
tensorflow
API calls must be replaced with theirkeras.ops
orkeras.layers
equivalents. - Testing Framework: The tests do not use the project's standard testing framework (
keras_hub.src.tests.test_case.TestCase
and its helper methods likerun_layer_test
). Adhering to the testing guidelines is crucial for maintainability. - Serialization: The layer is missing the
get_config
method and stores__init__
arguments in a way that prevents serialization, which is required for all layers. - Validation Notebook: The style guide requires a Colab notebook demonstrating numerical equivalence with an original implementation if one exists. Please add a link to it in the PR description.
I've left specific comments with suggestions on how to address these points. Great work on the core algorithm, and I look forward to seeing the updated version!
@@ -0,0 +1,127 @@ | |||
import tensorflow as tf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation uses the tensorflow
package directly, which violates the backend-agnostic principle of Keras Hub.1 All code must support TensorFlow, JAX, and PyTorch. Please refactor the layer to use keras.ops
and keras.layers
instead of tf.*
functions.
For example:
import tensorflow as tf
should be replaced withfrom keras import ops
andfrom keras import layers
.tf.keras.layers.Layer
should belayers.Layer
.tf.constant
should beops.convert_to_tensor
.tf.nn.convolution
should be replaced withops.conv
.tf.image.resize
should beops.image.resize
.- ... and so on for all other
tf
calls.
This is a fundamental requirement for all contributions.
Style Guide References
Footnotes
-
All code must be keras 3 backend-agnostic, supporting TensorFlow, JAX, and PyTorch backends. ↩
from tensorflow import keras | ||
from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import RandomElasticDeformation3D | ||
|
||
class RandomElasticDeformation3DTest(tf.test.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test class should inherit from keras_hub.src.tests.test_case.TestCase
instead of tf.test.TestCase
. The Keras Hub testing framework provides standardized helper methods that should be used to ensure consistent and thorough testing across the library.1
Please refactor the tests to use self.run_layer_test()
for basic checks and self.run_model_saving_test()
for serialization, as outlined in the contribution guidelines.2
Style Guide References
Footnotes
self.alpha = tf.constant(alpha, dtype=tf.bfloat16) | ||
self.sigma = tf.constant(sigma, dtype=tf.bfloat16) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This __init__
method has a couple of issues regarding style guide compliance and flexibility:
- It violates the style guide by not storing the original
alpha
andsigma
arguments as attributes.1 They are immediately converted to tensors, which prevents the layer from being serializable because aget_config()
method cannot be correctly implemented.2 - The
dtype
is hardcoded tobfloat16
. It's better to use the layer'scompute_dtype
to respect the model's overall dtype policy.
Please refactor __init__
to address these points. You will also need to add a get_config()
method to the class. You would then need to update the call
method to use internal tensor attributes (e.g., self._alpha_tensor
).
self.alpha = tf.constant(alpha, dtype=tf.bfloat16) | |
self.sigma = tf.constant(sigma, dtype=tf.bfloat16) | |
self.alpha = alpha | |
self.sigma = sigma |
Style Guide References
Footnotes
def test_config_serialization(self): | ||
layer = RandomElasticDeformation3D( | ||
grid_size=(3, 3, 3), | ||
alpha=50.0, | ||
sigma=5.0, | ||
data_format="HWDC" | ||
) | ||
config = layer.get_config() | ||
new_layer = RandomElasticDeformation3D.from_config(config) | ||
self.assertEqual(new_layer.grid_size, (3, 3, 3)) | ||
self.assertAllClose(new_layer.alpha, tf.constant(50.0, dtype=tf.bfloat16)) | ||
self.assertAllClose(new_layer.sigma, tf.constant(5.0, dtype=tf.bfloat16)) | ||
self.assertEqual(new_layer.data_format, "HWDC") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This serialization test will fail because the get_config
method is not implemented in the RandomElasticDeformation3D
layer. Per the style guide, you should use self.run_model_saving_test()
to test serialization, which is more comprehensive.1
Style Guide References
Footnotes
-
Use
self.run_model_saving_test()
for testing model serialization. ↩
warp_grid_floor = tf.floor(warp_grid) | ||
t = warp_grid - warp_grid_floor | ||
|
||
d0 = tf.cast(warp_grid_floor[..., 0], tf.int32); h0 = tf.cast(warp_grid_floor[..., 1], tf.int32); w0 = tf.cast(warp_grid_floor[..., 2], tf.int32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Placing multiple statements on a single line using semicolons is discouraged as it harms readability. Please split these assignments onto separate lines.
d0 = tf.cast(warp_grid_floor[..., 0], tf.int32); h0 = tf.cast(warp_grid_floor[..., 1], tf.int32); w0 = tf.cast(warp_grid_floor[..., 2], tf.int32) | |
d0 = tf.cast(warp_grid_floor[..., 0], tf.int32) | |
h0 = tf.cast(warp_grid_floor[..., 1], tf.int32) | |
w0 = tf.cast(warp_grid_floor[..., 2], tf.int32) |
input_image = tf.random.uniform(shape=(2, 32, 64, 64, 3), dtype=tf.float32) | ||
input_label = tf.random.uniform(shape=(2, 32, 64, 64, 1), maxval=4, dtype=tf.int32) | ||
layer = RandomElasticDeformation3D(data_format="DHWC") | ||
output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test casts integer labels to tf.float32
before passing them to the layer. For segmentation masks, it's common to use integer types. The layer should ideally support integer-dtype labels directly, and the tests should reflect that. The nearest-neighbor interpolation for labels should work correctly with integer types.
keras_hub/src/layers/preprocessing/random_elastic_deformation_3d_test.py
Outdated
Show resolved
Hide resolved
dc925ce
to
88836c8
Compare
88836c8
to
4cf6ee0
Compare
Hi Keras team, It looks like my changes for the RandomElasticDeformation3D layer are complete. All the tests related to formatting, serialization, and backend compatibility for my layer seem to be passing now. The CI is still failing on one last test, but the error log refers to Arange and T5/Moonshine models, which are unrelated to my changes. This appears to be a pre-existing issue in the main branch. Could you please take a look and perhaps re-run the failing job? Thanks! |
Hello Keras Hub Team,
This PR adds the
RandomElasticDeformation3D
layer as proposed and approved in issue #2409.Description
This new layer provides a state-of-the-art, TPU-optimized 3D elastic deformation augmentation. It is a pure, native TensorFlow implementation, making it fully graph-compatible and suitable for high-performance pipelines.
Key features include:
bfloat16
and a fast, separable 3D Gaussian filter.data_format
handling ("DHWC"
and"HWDC"
).I have included a comprehensive set of unit tests, which all pass locally, to ensure correctness and robustness. Looking forward to your feedback!
Fixes #2409