
Contribution: A Native 3D Elastic Deformation Augmentation Layer #2409

@ashpakshaikh26732


Hello KerasCV Team,

My name is Ashpak Shaikh, and I'm a developer working extensively with 3D medical image segmentation. I'd like to propose and contribute a new augmentation layer that I believe would be a valuable addition to KerasCV's growing 3D capabilities: RandomElasticDeformation3D.

Motivation
Elastic deformation is a widely used and highly effective augmentation technique for tasks involving non-rigid subjects, such as those in medical imaging. It simulates realistic tissue distortions, which makes models more robust to that kind of variation. While KerasCV has a fantastic suite of 2D augmentations, it currently lacks a pure, graph-native 3D elastic deformation layer. Such a layer is an important building block for state-of-the-art 3D segmentation pipelines (for example, ones inspired by nnU-Net) that run entirely within the TensorFlow/Keras ecosystem.

To fill this gap, I have developed a RandomElasticDeformation3D layer and would be excited to contribute it.

Implementation Details & Features
The layer is designed from the ground up to be a robust, production-ready component that aligns with KerasCV's design principles.

Pure TensorFlow Native: The entire implementation uses only native TensorFlow operations. It has no dependencies on external libraries like SciPy or TensorFlow Addons, ensuring it is fully graph-compatible and serializable.

Follows Established Methodology: It implements the standard "coarse-grid -> smooth -> warp" technique: it generates a low-resolution random displacement field, smooths it with a 3D Gaussian filter, and then warps the volume with a custom 3D resampling function.
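For readers who have not opened the Gist, the three steps can be sketched in plain TensorFlow roughly as follows. This is a minimal, unbatched illustration, not the Gist's code: the function names, the default `grid_size`, `magnitude` and `sigma` values, and in particular the choice to upsample the coarse field to full resolution by trilinear interpolation are all hypothetical assumptions made for the sketch.

```python
import tensorflow as tf


def gaussian_kernel_3d(sigma, radius):
    """Normalized 3D Gaussian kernel shaped [k, k, k, 1, 1] for tf.nn.conv3d."""
    x = tf.range(-radius, radius + 1, dtype=tf.float32)
    g = tf.exp(-0.5 * (x / sigma) ** 2)
    g = g / tf.reduce_sum(g)
    # Outer product of three normalized 1D Gaussians: a separable kernel summing to 1.
    k = g[:, None, None] * g[None, :, None] * g[None, None, :]
    return k[..., None, None]


def voxel_grid(size):
    """Float voxel coordinates (z, y, x) with shape [D, H, W, 3]."""
    zz, yy, xx = tf.meshgrid(tf.range(size[0]), tf.range(size[1]),
                             tf.range(size[2]), indexing="ij")
    return tf.cast(tf.stack([zz, yy, xx], axis=-1), tf.float32)


def trilinear_sample(vol, coords):
    """Sample vol [D, H, W, C] at float voxel coords [..., 3], clamped to the border."""
    size = tf.shape(vol)[:3]
    coords = tf.minimum(tf.maximum(coords, 0.0), tf.cast(size - 1, tf.float32))
    lo = tf.floor(coords)
    frac = coords - lo
    lo = tf.cast(lo, tf.int32)
    hi = tf.minimum(lo + 1, size - 1)
    out = 0.0
    # Accumulate the 8 surrounding voxels, each weighted by its trilinear weight.
    for cz in (0, 1):
        for cy in (0, 1):
            for cx in (0, 1):
                corner = [hi if c else lo for c in (cz, cy, cx)]
                idx = tf.stack([corner[0][..., 0], corner[1][..., 1],
                                corner[2][..., 2]], axis=-1)
                w = 1.0
                for axis, c in enumerate((cz, cy, cx)):
                    w = w * (frac[..., axis] if c else 1.0 - frac[..., axis])
                out = out + tf.gather_nd(vol, idx) * w[..., None]
    return out


def elastic_deform_3d(volume, grid_size=(5, 5, 5), magnitude=3.0, sigma=1.0, seed=None):
    """Hypothetical sketch: coarse grid -> Gaussian smoothing -> upsample -> warp.

    volume: float32 [D, H, W, C]. magnitude: std. dev. of coarse displacements
    in voxels. sigma: Gaussian width measured in coarse-grid cells.
    """
    full = tf.shape(volume)[:3]
    # 1. Random (dz, dy, dx) displacement vectors on a low-resolution grid.
    coarse = tf.random.normal(list(grid_size) + [3], stddev=magnitude, seed=seed)
    # 2. Smooth each component independently (components moved to the batch axis).
    comps = tf.transpose(coarse, [3, 0, 1, 2])[..., None]  # [3, gd, gh, gw, 1]
    radius = max(1, int(2 * sigma))  # truncate the kernel at about 2 sigma
    comps = tf.nn.conv3d(comps, gaussian_kernel_3d(sigma, radius),
                         strides=[1, 1, 1, 1, 1], padding="SAME")
    coarse = tf.transpose(comps[..., 0], [1, 2, 3, 0])  # back to [gd, gh, gw, 3]
    # 3. Upsample the coarse field to full resolution (assumed: trilinear).
    grid = voxel_grid(full)
    scale = (tf.cast(tf.constant(grid_size) - 1, tf.float32)
             / tf.cast(tf.maximum(full - 1, 1), tf.float32))
    displacement = trilinear_sample(coarse, grid * scale)
    # 4. Warp: read the input at the displaced coordinates.
    return trilinear_sample(volume, grid + displacement)
```

Because the random field has only a few control points per axis, the resulting deformation is smooth at the voxel scale; in this sketch `magnitude` sets how far voxels move and `grid_size` sets how local the distortion is. Note that `SAME` zero padding slightly damps displacements near the edges of the coarse grid.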

Keras Layer Subclass: It is built as a standard tf.keras.layers.Layer, so it integrates seamlessly into any Keras workflow (Sequential models, Functional API, model.fit, etc.).

Flexible Data Format: It supports both data_format="DHWC" (Depth, Height, Width, Channels) and data_format="HWDC" (Height, Width, Depth, Channels) via a constructor argument.
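One simple way to support both layouts, assumed here rather than taken from the Gist, is to transpose HWDC inputs to DHWC on entry, run the deformation in that single canonical layout, and transpose back on exit. The helper names `to_dhwc` and `from_dhwc` are hypothetical.

```python
import tensorflow as tf

# Axis permutations keyed by tensor rank; a leading batch axis (rank 5) stays first.
_HWDC_TO_DHWC = {4: [2, 0, 1, 3], 5: [0, 3, 1, 2, 4]}
_DHWC_TO_HWDC = {4: [1, 2, 0, 3], 5: [0, 2, 3, 1, 4]}


def to_dhwc(x, data_format):
    """Return x in DHWC (or NDHWC) layout."""
    if data_format == "DHWC":
        return x
    if data_format == "HWDC":
        return tf.transpose(x, _HWDC_TO_DHWC[x.shape.rank])
    raise ValueError(f"Unsupported data_format: {data_format!r}")


def from_dhwc(x, data_format):
    """Inverse of to_dhwc: convert a DHWC (or NDHWC) tensor back to data_format."""
    if data_format == "DHWC":
        return x
    if data_format == "HWDC":
        return tf.transpose(x, _DHWC_TO_HWDC[x.shape.rank])
    raise ValueError(f"Unsupported data_format: {data_format!r}")
```

Keeping a single internal layout means the sampling code only has to be written and tested once; the cost is two transposes per call.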

Handles Batched & Unbatched Inputs: The layer transparently handles both 4D single-volume tensors and 5D batched tensors.
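A sketch of how this dispatch might look, assuming a per-volume function `fn(image, label)` (hypothetical) that performs the actual deformation on one 4D pair: 4D inputs are passed straight through, and 5D batches are mapped one sample at a time with `tf.map_fn`.

```python
import tensorflow as tf


def apply_per_volume(fn, image, label):
    """Apply fn(image, label) to one [D, H, W, C] pair or to each pair in an [N, D, H, W, C] batch."""
    rank = image.shape.rank
    if rank == 4:
        return fn(image, label)
    if rank == 5:
        # map_fn slices the batch axis of both tensors in lockstep; the output is
        # assumed to have the same structure and dtypes as (image, label).
        return tf.map_fn(lambda pair: fn(pair[0], pair[1]), (image, label))
    raise ValueError(f"Expected a 4D or 5D image tensor, got rank {rank}")
```

Mapping per sample lets every volume in a batch draw its own random displacement field, at the cost of sequential execution; a fully vectorized alternative would require the sampler itself to accept a batch axis.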

Synchronized Image & Label Augmentation: The layer takes a tuple of (image, label) as input and applies the identical deformation to both, using trilinear interpolation for the image and nearest-neighbor interpolation for the segmentation mask to preserve label integrity.
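A minimal sketch of the label side, assuming the layer samples the mask at the same displaced coordinates it uses for the image; `nearest_sample` and `warp_label` are hypothetical names. Rounding to the nearest voxel means the output can only contain label IDs that already exist in the input, whereas linear interpolation 0.6 of the way between voxels labelled 0 and 3 would produce a meaningless value of 1.8.

```python
import tensorflow as tf


def nearest_sample(vol, coords):
    """Nearest-neighbour lookup of vol [D, H, W, C] at float voxel coords [..., 3] (z, y, x)."""
    size = tf.shape(vol)[:3]
    # tf.round rounds halves to even; coordinates are then clamped to the border.
    idx = tf.cast(tf.round(coords), tf.int32)
    idx = tf.minimum(tf.maximum(idx, 0), size - 1)
    return tf.gather_nd(vol, idx)


def warp_label(label, displacement):
    """Warp an integer label volume [D, H, W, C] by a dense displacement field [D, H, W, 3] in voxels."""
    size = tf.shape(label)[:3]
    zz, yy, xx = tf.meshgrid(tf.range(size[0]), tf.range(size[1]),
                             tf.range(size[2]), indexing="ij")
    grid = tf.cast(tf.stack([zz, yy, xx], axis=-1), tf.float32)
    # The image would be sampled at these same coordinates, but trilinearly.
    return nearest_sample(label, grid + displacement)
```

For example, shifting a 1 x 1 x 4 mask `[0, 3, 0, 3]` by 0.6 voxels along x yields `[3, 0, 3, 3]`: every output value is still a valid label, and the integer dtype is preserved.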

Code
I have prepared the complete, self-contained code for the layer in the following GitHub Gist for your review:

Link to RandomElasticDeformation3D.py Gist

Next Steps
I am very keen to contribute this to the community through KerasCV. I am fully prepared to work with the team to add comprehensive docstrings, unit tests, and any other modifications necessary to meet the library's contribution standards and prepare a formal Pull Request.

Thank you for your time and consideration. I look forward to your feedback and guidance.

Best regards,
Ashpak Shaikh
