Description
Hello KerasCV Team,
My name is Ashpak Shaikh, and I'm a developer working extensively with 3D medical image segmentation. I'd like to propose and contribute a new augmentation layer that I believe would be a valuable addition to KerasCV's growing 3D capabilities: RandomElasticDeformation3D.
Motivation
Elastic deformation is a widely used and highly effective augmentation technique for tasks involving non-rigid subjects, such as anatomical structures in medical imaging. It simulates realistic tissue distortions, which makes models more robust. While KerasCV has an extensive suite of 2D augmentations, it currently lacks a pure, graph-native 3D elastic deformation layer. Adding one is a critical step toward building state-of-the-art 3D pipelines (inspired by frameworks like nnU-Net) entirely within the TensorFlow/Keras ecosystem.
To fill this gap, I have developed a RandomElasticDeformation3D layer and would be excited to contribute it.
Implementation Details & Features
The layer is designed from the ground up to be a robust, production-ready component that aligns with KerasCV's design principles.
Pure TensorFlow: The entire implementation uses only native TensorFlow operations. It has no dependencies on external libraries such as SciPy or TensorFlow Addons, so it is fully graph-compatible and serializable.
Follows Standard Methodology: It implements the widely used "coarse-grid -> smooth -> warp" technique: it creates a low-resolution random displacement field, smooths it with a 3D Gaussian filter, and then warps the volume with a custom 3D resampling function.
Keras Layer Subclass: It subclasses tf.keras.layers.Layer, so it integrates seamlessly into any Keras workflow (Sequential models, the Functional API, model.fit, etc.).
Flexible Data Format: It supports both data_format="DHWC" (Depth, Height, Width, Channels) and data_format="HWDC" (Height, Width, Depth, Channels) via a constructor argument.
Handles Batched & Unbatched Inputs: The layer transparently handles both 4D single-volume tensors and 5D batched tensors.
Synchronized Image & Label Augmentation: The layer takes an (image, label) tuple as input and applies the exact same deformation to both, using trilinear interpolation for the image and nearest-neighbor interpolation for the segmentation mask, so that the deformed mask contains no new (interpolated) label values.
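The "coarse-grid -> smooth -> warp" pipeline and the synchronized image/label handling described above can be sketched as follows. This is a NumPy reference sketch, not the TensorFlow layer from the Gist: the function name random_elastic_deformation_3d and its parameters (grid_size, sigma, magnitude) are hypothetical, upsampling the smoothed coarse field with trilinear interpolation is an assumption (the proposal does not state how the layer upsamples), and out-of-bounds sample coordinates are simply clamped to the volume border.

```python
import numpy as np


def _gaussian_kernel_1d(sigma):
    # Normalized 1D Gaussian, truncated at about 3 sigma.
    radius = max(1, int(3 * sigma))
    t = np.arange(-radius, radius + 1, dtype=np.float64)
    k = np.exp(-0.5 * (t / sigma) ** 2)
    return k / k.sum()


def _smooth_3d(field, sigma):
    # Separable Gaussian over the three spatial axes of a (D, H, W, 3) field.
    # Edge padding + "valid" convolution keeps every axis the same length.
    k = _gaussian_kernel_1d(sigma)
    r = len(k) // 2
    conv = lambda v: np.convolve(np.pad(v, r, mode="edge"), k, mode="valid")
    for axis in range(3):
        field = np.apply_along_axis(conv, axis, field)
    return field


def _trilinear(vol, coords):
    # vol: (D, H, W, C); coords: (3, ...) float voxel coordinates (z, y, x).
    d, h, w = vol.shape[:3]
    z = np.clip(coords[0], 0, d - 1)
    y = np.clip(coords[1], 0, h - 1)
    x = np.clip(coords[2], 0, w - 1)
    z0, y0, x0 = (np.floor(a).astype(int) for a in (z, y, x))
    z1, y1, x1 = np.minimum(z0 + 1, d - 1), np.minimum(y0 + 1, h - 1), np.minimum(x0 + 1, w - 1)
    fz, fy, fx = ((a - a0)[..., None] for a, a0 in ((z, z0), (y, y0), (x, x0)))
    # Interpolate along x, then y, then z.
    c00 = vol[z0, y0, x0] * (1 - fx) + vol[z0, y0, x1] * fx
    c01 = vol[z0, y1, x0] * (1 - fx) + vol[z0, y1, x1] * fx
    c10 = vol[z1, y0, x0] * (1 - fx) + vol[z1, y0, x1] * fx
    c11 = vol[z1, y1, x0] * (1 - fx) + vol[z1, y1, x1] * fx
    c0 = c00 * (1 - fy) + c01 * fy
    c1 = c10 * (1 - fy) + c11 * fy
    return c0 * (1 - fz) + c1 * fz


def _nearest(vol, coords):
    # Nearest-neighbour lookup: every output value is copied from vol.
    idx = [np.clip(np.rint(c), 0, n - 1).astype(int) for c, n in zip(coords, vol.shape[:3])]
    return vol[idx[0], idx[1], idx[2]]


def random_elastic_deformation_3d(image, label, grid_size=4, sigma=1.0,
                                  magnitude=4.0, data_format="DHWC", rng=None):
    # Hypothetical reference function; magnitude is the std (in voxels) of the
    # coarse displacements before smoothing.
    rng = np.random.default_rng() if rng is None else rng
    if image.ndim == 5:  # batched input: deform each volume independently
        pairs = [random_elastic_deformation_3d(i, l, grid_size, sigma, magnitude, data_format, rng)
                 for i, l in zip(image, label)]
        return np.stack([p[0] for p in pairs]), np.stack([p[1] for p in pairs])
    if data_format == "HWDC":  # move depth to the front, deform, move it back
        img, lbl = random_elastic_deformation_3d(
            np.transpose(image, (2, 0, 1, 3)), np.transpose(label, (2, 0, 1, 3)),
            grid_size, sigma, magnitude, "DHWC", rng)
        return np.transpose(img, (1, 2, 0, 3)), np.transpose(lbl, (1, 2, 0, 3))
    d, h, w = image.shape[:3]
    # 1. Coarse random field: one (dz, dy, dx) vector per grid node.
    coarse = rng.normal(0.0, magnitude, size=(grid_size,) * 3 + (3,))
    # 2. Gaussian smoothing so neighbouring nodes move coherently.
    coarse = _smooth_3d(coarse, sigma)
    # 3. Upsample to full resolution (trilinear; an assumption of this sketch).
    zz, yy, xx = np.meshgrid(np.arange(d), np.arange(h), np.arange(w), indexing="ij")
    grid = np.stack([zz, yy, xx]).astype(np.float64)
    s = grid_size - 1
    scale = np.array([s / max(d - 1, 1), s / max(h - 1, 1), s / max(w - 1, 1)]).reshape(3, 1, 1, 1)
    disp = np.moveaxis(_trilinear(coarse, grid * scale), -1, 0)  # (3, D, H, W)
    # 4. Warp both tensors with the same sampling coordinates.
    coords = grid + disp
    return _trilinear(image, coords), _nearest(label, coords)
```

Because the image and the mask are sampled at identical coordinates, they stay aligned, and the nearest-neighbour lookup can only copy existing label values. A NumPy version like this could also serve as a slow but easy-to-read reference when writing unit tests for the graph-native layer, for example checking that output shapes match the inputs and that the deformed mask contains only labels present in the original.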
Code
I have prepared the complete, self-contained code for the layer in the following GitHub Gist for your review:
Link to RandomElasticDeformation3D.py Gist
Next Steps
I am very keen to contribute this to the community through KerasCV. I am fully prepared to work with the team to add comprehensive docstrings, unit tests, and any other modifications necessary to meet the library's contribution standards and prepare a formal Pull Request.
Thank you for your time and consideration. I look forward to your feedback and guidance.
Best regards,
Ashpak Shaikh