-
Notifications
You must be signed in to change notification settings - Fork 301
feat(layers): Add TPU-optimized 3D Elastic Deformation layer #2419
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
3665ebb
9793ab9
cb92954
fa19ac9
4cf6ee0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,127 @@ | ||||||||||
import tensorflow as tf | ||||||||||
|
||||||||||
class RandomElasticDeformation3D(tf.keras.layers.Layer): | ||||||||||
""" | ||||||||||
A high-performance 3D elastic deformation layer optimized for TPUs and GPUs. | ||||||||||
... (docstring is the same) ... | ||||||||||
""" | ||||||||||
def __init__(self, | ||||||||||
grid_size=(4, 4, 4), | ||||||||||
alpha=35.0, | ||||||||||
sigma=2.5, | ||||||||||
data_format="DHWC", | ||||||||||
**kwargs): | ||||||||||
super().__init__(**kwargs) | ||||||||||
self.grid_size = grid_size | ||||||||||
self.alpha = tf.constant(alpha, dtype=tf.bfloat16) | ||||||||||
self.sigma = tf.constant(sigma, dtype=tf.bfloat16) | ||||||||||
|
self.alpha = tf.constant(alpha, dtype=tf.bfloat16) | |
self.sigma = tf.constant(sigma, dtype=tf.bfloat16) | |
self.alpha = alpha | |
self.sigma = sigma |
Style Guide References
Footnotes
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Placing multiple statements on a single line using semicolons is discouraged as it harms readability. Please split these assignments onto separate lines.
d0 = tf.cast(warp_grid_floor[..., 0], tf.int32); h0 = tf.cast(warp_grid_floor[..., 1], tf.int32); w0 = tf.cast(warp_grid_floor[..., 2], tf.int32) | |
d0 = tf.cast(warp_grid_floor[..., 0], tf.int32) | |
h0 = tf.cast(warp_grid_floor[..., 1], tf.int32) | |
w0 = tf.cast(warp_grid_floor[..., 2], tf.int32) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import tensorflow as tf | ||
from tensorflow import keras | ||
from keras_hub.src.layers.preprocessing.random_elastic_deformation_3d import RandomElasticDeformation3D | ||
|
||
class RandomElasticDeformation3DTest(tf.test.TestCase): | ||
|
||
|
||
def test_output_shape_is_same_as_input_dhwc(self): | ||
input_image = tf.random.uniform(shape=(2, 32, 64, 64, 3), dtype=tf.float32) | ||
input_label = tf.random.uniform(shape=(2, 32, 64, 64, 1), maxval=4, dtype=tf.int32) | ||
layer = RandomElasticDeformation3D(data_format="DHWC") | ||
output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) | ||
|
||
self.assertAllEqual(tf.shape(input_image), tf.shape(output_image)) | ||
self.assertAllEqual(tf.shape(input_label), tf.shape(output_label)) | ||
|
||
def test_output_shape_is_same_as_input_hwdc(self): | ||
input_image = tf.random.uniform(shape=(2, 64, 64, 32, 3), dtype=tf.float32) | ||
input_label = tf.random.uniform(shape=(2, 64, 64, 32, 1), maxval=4, dtype=tf.int32) | ||
layer = RandomElasticDeformation3D(data_format="HWDC") | ||
output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) | ||
self.assertAllEqual(tf.shape(input_image), tf.shape(output_image)) | ||
self.assertAllEqual(tf.shape(input_label), tf.shape(output_label)) | ||
|
||
def test_unbatched_input(self): | ||
input_image = tf.random.uniform(shape=(32, 64, 64, 3), dtype=tf.float32) | ||
input_label = tf.random.uniform(shape=(32, 64, 64, 1), maxval=4, dtype=tf.int32) | ||
layer = RandomElasticDeformation3D(data_format="DHWC") | ||
output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) | ||
self.assertAllEqual(tf.shape(input_image), tf.shape(output_image)) | ||
self.assertEqual(tf.rank(output_image), 4) | ||
|
||
def test_dtype_preservation(self): | ||
input_image = tf.random.uniform(shape=(2, 16, 16, 16, 3), dtype=tf.float32) | ||
input_label = tf.random.uniform(shape=(2, 16, 16, 16, 1), maxval=4, dtype=tf.int32) | ||
layer = RandomElasticDeformation3D() | ||
output_image, output_label = layer((input_image, tf.cast(input_label, tf.float32))) | ||
self.assertEqual(output_image.dtype, tf.float32) | ||
self.assertEqual(output_label.dtype, tf.float32) | ||
|
||
def test_label_values_are_preserved(self): | ||
input_image = tf.zeros(shape=(1, 16, 16, 16, 1), dtype=tf.float32) | ||
label_arange = tf.experimental.numpy.arange(16**3) | ||
ashpakshaikh26732 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
input_label = tf.reshape(label_arange, (1, 16, 16, 16, 1)) | ||
input_label = tf.cast(input_label, dtype=tf.float32) % 4 | ||
|
||
layer = RandomElasticDeformation3D(alpha=80.0, sigma=8.0) | ||
_, output_label = layer((input_image, input_label)) | ||
|
||
unique_values_tensor = tf.unique(tf.reshape(output_label, [-1]))[0] | ||
|
||
|
||
expected_values = [0., 1., 2., 3.] | ||
actual_values = unique_values_tensor.numpy().tolist() | ||
self.assertContainsSubset(expected_values, actual_values) | ||
|
||
def test_config_serialization(self): | ||
layer = RandomElasticDeformation3D( | ||
grid_size=(3, 3, 3), | ||
alpha=50.0, | ||
sigma=5.0, | ||
data_format="HWDC" | ||
) | ||
config = layer.get_config() | ||
new_layer = RandomElasticDeformation3D.from_config(config) | ||
self.assertEqual(new_layer.grid_size, (3, 3, 3)) | ||
self.assertAllClose(new_layer.alpha, tf.constant(50.0, dtype=tf.bfloat16)) | ||
self.assertAllClose(new_layer.sigma, tf.constant(5.0, dtype=tf.bfloat16)) | ||
self.assertEqual(new_layer.data_format, "HWDC") | ||
|
||
|
||
if __name__ == "__main__": | ||
tf.test.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation uses the
tensorflow
package directly, which violates the backend-agnostic principle of Keras Hub.1 All code must support TensorFlow, JAX, and PyTorch. Please refactor the layer to usekeras.ops
andkeras.layers
instead oftf.*
functions.For example:
import tensorflow as tf
should be replaced withfrom keras import ops
andfrom keras import layers
.tf.keras.layers.Layer
should belayers.Layer
.tf.constant
should beops.convert_to_tensor
.tf.nn.convolution
should be replaced withops.conv
.tf.image.resize
should beops.image.resize
.tf
calls.This is a fundamental requirement for all contributions.
Style Guide References
Footnotes
All code must be keras 3 backend-agnostic, supporting TensorFlow, JAX, and PyTorch backends. ↩