Skip to content

Commit 2e03784

Browse files
authored
Add train_transforms and val_transforms to train_timm_segmentation_model (#606)
Expose the existing transform support in SegmentationDataset through the high-level train_timm_segmentation_model() API, matching the interface already available in train_segmentation_model(). Closes #605
1 parent 4b3a992 commit 2e03784

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

geoai/timm_segment.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,8 @@ def train_timm_segmentation_model(
642642
device: Optional[str] = None,
643643
use_timm_model: bool = False,
644644
timm_model_name: Optional[str] = None,
645+
train_transforms: Optional[Callable] = None,
646+
val_transforms: Optional[Callable] = None,
645647
**kwargs: Any,
646648
) -> torch.nn.Module:
647649
"""
@@ -684,6 +686,15 @@ def train_timm_segmentation_model(
684686
device (str, optional): Device to use. Auto-detected if None.
685687
use_timm_model (bool): Load complete segmentation model from timm/HF Hub.
686688
timm_model_name (str, optional): Model name from HF Hub (e.g., 'hf-hub:nvidia/mit-b0').
689+
train_transforms (callable, optional): Custom transforms for training data.
690+
Should be a callable that accepts (image, mask) tensors and returns
691+
transformed (image, mask). Both image and mask should be torch.Tensor
692+
objects. The image tensor is in CHW format (channels, height, width),
693+
and the mask tensor in HW format (height, width). If None, no
694+
augmentation is applied. Defaults to None.
695+
val_transforms (callable, optional): Custom transforms for validation data.
696+
Same signature as train_transforms. If None, no augmentation is
697+
applied. Defaults to None.
687698
**kwargs: Additional arguments for training.
688699
689700
Returns:
@@ -757,12 +768,14 @@ def train_timm_segmentation_model(
757768
image_paths=train_images,
758769
mask_paths=train_labels,
759770
num_channels=num_channels,
771+
transform=train_transforms,
760772
)
761773

762774
val_dataset = SegmentationDataset(
763775
image_paths=val_images,
764776
mask_paths=val_labels,
765777
num_channels=num_channels,
778+
transform=val_transforms,
766779
)
767780

768781
# Train model

0 commit comments

Comments
 (0)