Skip to content

Commit a7ef623

Browse files
authored
Add support for preprocessing inputs in training CLI (#879)
Add support for preprocessing inputs in training CLI
1 parent 2c2878c commit a7ef623

File tree

3 files changed

+64
-6
lines changed

3 files changed

+64
-6
lines changed

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -178,6 +178,8 @@ examples/data/*
178178
# Torch-em stuff
179179
checkpoints/
180180
logs/
181+
*.pth
182+
*.pt
181183

182184
# And some other stuff to avoid tracking as well.
183185
gpu_jobs/
@@ -189,6 +191,10 @@ iterative_prompting_results/
189191
*.sh
190192
*.svg
191193
*.csv
194+
*.tiff
195+
*.tif
196+
*.zip
197+
*MACOSX
192198

193199
# Related to i2k workshop folders.
194200
data/

micro_sam/training/training.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
from ..instance_segmentation import get_unetr
3030
from . import joint_sam_trainer as joint_trainers
3131
from ..util import get_device, get_model_names, export_custom_sam_model
32-
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit
32+
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform
3333

3434

3535
FilePath = Union[str, os.PathLike]
@@ -366,6 +366,9 @@ def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
366366
image_path = glob(os.path.join(path, raw_key))[0]
367367
ndim = imageio.imread(image_path).ndim
368368

369+
if not isinstance(patch_shape, tuple):
370+
patch_shape = tuple(patch_shape)
371+
369372
if ndim == 2:
370373
assert len(patch_shape) == 2
371374
return patch_shape
@@ -487,6 +490,7 @@ def default_sam_dataset(
487490
with_channels=with_channels,
488491
ndim=2,
489492
is_seg_dataset=is_seg_dataset,
493+
raw_transform=raw_transform,
490494
**kwargs
491495
)
492496
n_samples = max(len(loader), 100 if is_train else 5)
@@ -627,6 +631,10 @@ def train_sam_for_configuration(
627631

628632
def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader):
629633

634+
# Whether the model is stored in the current working directory or in another location.
635+
if save_root is None:
636+
save_root = os.getcwd() # Map this to current working directory, if not specified by the user.
637+
630638
# Get the 'best' model checkpoint ready for export.
631639
best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt")
632640
if not os.path.exists(best_checkpoint):
@@ -734,7 +742,8 @@ def main():
734742
)
735743
parser.add_argument(
736744
"--segmentation_decoder", type=str, default="instances", # TODO: in future, we can extend this to semantic seg.
737-
help="Whether to finetune Segment Anything Model with additional instance segmentation decoder."
745+
help="Whether to finetune Segment Anything Model with additional segmentation decoder for desired targets. "
746+
"By default, it trains with the additional segmentation decoder for instance segmentation."
738747
)
739748

740749
# Optional advanced settings a user can opt to change the values for.
@@ -779,6 +788,11 @@ def main():
779788
"--batch_size", type=int, default=1,
780789
help="The choice of batch size for training the Segment Anything Model. By default, trains on batch size 1."
781790
)
791+
parser.add_argument(
792+
"--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"),
793+
help="Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images "
794+
"Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'."
795+
)
782796

783797
args = parser.parse_args()
784798

@@ -802,6 +816,9 @@ def main():
802816

803817
# 2. Prepare the dataloaders.
804818

819+
# If the user wants to preprocess the inputs, we allow the possibility to do so.
820+
_raw_transform = get_raw_transform(args.preprocess)
821+
805822
# Get the dataset with files for training.
806823
dataset = default_sam_dataset(
807824
raw_paths=train_images,
@@ -810,13 +827,14 @@ def main():
810827
label_key=train_gt_key,
811828
patch_shape=patch_shape,
812829
with_segmentation_decoder=with_segmentation_decoder,
830+
raw_transform=_raw_transform,
813831
)
814832

815833
# If val images are not exclusively provided, we create a val split from the training data.
816834
if val_images is None:
817835
assert val_gt is None and val_image_key is None and val_gt_key is None
818836
# Use 10% of the dataset - at least one image - for validation.
819-
n_val = min(1, int(0.1 * len(dataset)))
837+
n_val = max(1, int(0.1 * len(dataset)))
820838
train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])
821839

822840
else: # If val images provided, we create a new dataset for it.
@@ -828,6 +846,7 @@ def main():
828846
label_key=val_gt_key,
829847
patch_shape=patch_shape,
830848
with_segmentation_decoder=with_segmentation_decoder,
849+
raw_transform=_raw_transform,
831850
)
832851

833852
# Get the dataloaders from the datasets.
@@ -845,8 +864,6 @@ def main():
845864
if model_type is None: # If user does not specify the model, we use the default model corresponding to the config.
846865
model_type = CONFIGURATIONS[config]["model_type"]
847866

848-
print(model_type, config)
849-
850867
train_sam_for_configuration(
851868
name=checkpoint_name,
852869
configuration=config,

micro_sam/training/util.py

Lines changed: 36 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import os
22
from math import ceil, floor
3-
from typing import Dict, List, Optional, Union, Tuple
3+
from functools import partial
4+
from typing import Dict, List, Optional, Union, Tuple, Callable
45

56
import numpy as np
67

@@ -39,6 +40,40 @@ def require_8bit(x):
3940
return x
4041

4142

43+
def _raw_transform(image: np.ndarray, raw_trafo: Callable) -> np.ndarray:
44+
return raw_trafo(image) * 255
45+
46+
47+
def _normalize_percentile(image: np.ndarray) -> np.ndarray:
48+
image = normalize_percentile(image) # Use 1st and 99th percentile values for min-max normalization.
49+
image = np.clip(image, 0, 1) # Clip the values to be in range [0, 1].
50+
return image
51+
52+
53+
def get_raw_transform(preprocess: Optional[str] = None) -> Optional[Callable]:
54+
"""Transformation functions to normalize inputs.
55+
56+
Args:
57+
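preprocess: The name of the preprocessing to apply to the raw inputs. By default, it is set to 'None',
58+
in which case 'require_8bit' is returned to ensure that inputs are 8-bit. Otherwise, choose from 'normalize_minmax' / 'normalize_percentile'.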
59+
60+
Returns:
61+
The transformation function.
62+
"""
63+
64+
if preprocess is None: # Ensures that inputs are 8-bit.
65+
return require_8bit
66+
else:
67+
if preprocess == "normalize_minmax":
68+
raw_trafo = normalize
69+
elif preprocess == "normalize_percentile":
70+
raw_trafo = _normalize_percentile
71+
else:
72+
raise ValueError(f"'{preprocess}' is not a supported preprocessing.")
73+
74+
return partial(_raw_transform, raw_trafo=raw_trafo)
75+
76+
4277
def get_trainable_sam_model(
4378
model_type: str = _DEFAULT_MODEL,
4479
device: Optional[Union[str, torch.device]] = None,

0 commit comments

Comments
 (0)