Skip to content

Add train_transforms and val_transforms to train_timm_segmentation_model#606

Merged
giswqs merged 1 commit intomainfrom
add-transforms-to-timm-segmentation
Mar 8, 2026
Merged

Add train_transforms and val_transforms to train_timm_segmentation_model#606
giswqs merged 1 commit intomainfrom
add-transforms-to-timm-segmentation

Conversation

@giswqs
Copy link
Member

@giswqs giswqs commented Mar 8, 2026

Summary

  • Adds train_transforms and val_transforms parameters to train_timm_segmentation_model(), matching the existing interface in train_segmentation_model()
  • The underlying SegmentationDataset already supported transforms — this change exposes that capability through the high-level API

Closes #605

Test plan

  • Verify train_timm_segmentation_model() accepts custom transforms and passes them to the dataset
  • Verify default behavior (no transforms) is unchanged when parameters are omitted

Expose the existing transform support in SegmentationDataset through the
high-level train_timm_segmentation_model() API, matching the interface
already available in train_segmentation_model().

Closes #605
Copilot AI review requested due to automatic review settings March 8, 2026 02:33
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Exposes dataset-level augmentation hooks in train_timm_segmentation_model() by adding train_transforms and val_transforms parameters and wiring them through to the underlying SegmentationDataset, aligning the high-level API more closely with train_segmentation_model().

Changes:

  • Add train_transforms and val_transforms parameters to train_timm_segmentation_model().
  • Document the expected transform callable signature and tensor formats.
  • Pass the provided transforms into SegmentationDataset(transform=...) for train/val splits.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 767 to 772
train_dataset = SegmentationDataset(
image_paths=train_images,
mask_paths=train_labels,
num_channels=num_channels,
transform=train_transforms,
)
Copy link

Copilot AI Mar 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change introduces new public parameters (train_transforms/val_transforms) and forwards them into SegmentationDataset, but there are no automated tests that assert the callables are actually passed through and invoked. Add a unit test (e.g., patch SegmentationDataset and/or provide a transform that mutates the tensors) to verify both train and val transforms are applied, and that omitting them preserves the previous behavior.

Copilot uses AI. Check for mistakes.
@github-actions
Copy link

github-actions bot commented Mar 8, 2026

@github-actions github-actions bot temporarily deployed to pull request March 8, 2026 02:37 Inactive
@giswqs giswqs merged commit 2e03784 into main Mar 8, 2026
14 checks passed
@giswqs giswqs deleted the add-transforms-to-timm-segmentation branch March 8, 2026 02:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants