Add train_transforms and val_transforms to train_timm_segmentation_model#606
Add train_transforms and val_transforms to train_timm_segmentation_model#606
Conversation
Expose the existing transform support in SegmentationDataset through the high-level train_timm_segmentation_model() API, matching the interface already available in train_segmentation_model(). Closes #605
There was a problem hiding this comment.
Pull request overview
Exposes dataset-level augmentation hooks in train_timm_segmentation_model() by adding train_transforms and val_transforms parameters and wiring them through to the underlying SegmentationDataset, aligning the high-level API more closely with train_segmentation_model().
Changes:
- Add
train_transformsandval_transformsparameters totrain_timm_segmentation_model(). - Document the expected transform callable signature and tensor formats.
- Pass the provided transforms into
SegmentationDataset(transform=...)for train/val splits.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| train_dataset = SegmentationDataset( | ||
| image_paths=train_images, | ||
| mask_paths=train_labels, | ||
| num_channels=num_channels, | ||
| transform=train_transforms, | ||
| ) |
There was a problem hiding this comment.
This change introduces new public parameters (train_transforms/val_transforms) and forwards them into SegmentationDataset, but there are no automated tests that assert the callables are actually passed through and invoked. Add a unit test (e.g., patch SegmentationDataset and/or provide a transform that mutates the tensors) to verify both train and val transforms are applied, and that omitting them preserves the previous behavior.
|
🚀 Deployed on https://69ace0cf57cf509481ea74b1--opengeos.netlify.app |
Summary
train_transformsandval_transformsparameters totrain_timm_segmentation_model(), matching the existing interface intrain_segmentation_model()SegmentationDatasetalready supported transforms — this change exposes that capability through the high-level APICloses #605
Test plan
train_timm_segmentation_model()accepts custom transforms and passes them to the dataset