2929from ..instance_segmentation import get_unetr
3030from . import joint_sam_trainer as joint_trainers
3131from ..util import get_device , get_model_names , export_custom_sam_model
32- from .util import get_trainable_sam_model , ConvertToSamInputs , require_8bit
32+ from .util import get_trainable_sam_model , ConvertToSamInputs , require_8bit , get_raw_transform
3333
3434
3535FilePath = Union [str , os .PathLike ]
@@ -366,6 +366,9 @@ def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
366366 image_path = glob (os .path .join (path , raw_key ))[0 ]
367367 ndim = imageio .imread (image_path ).ndim
368368
369+ if not isinstance (patch_shape , tuple ):
370+ patch_shape = tuple (patch_shape )
371+
369372 if ndim == 2 :
370373 assert len (patch_shape ) == 2
371374 return patch_shape
@@ -487,6 +490,7 @@ def default_sam_dataset(
487490 with_channels = with_channels ,
488491 ndim = 2 ,
489492 is_seg_dataset = is_seg_dataset ,
493+ raw_transform = raw_transform ,
490494 ** kwargs
491495 )
492496 n_samples = max (len (loader ), 100 if is_train else 5 )
@@ -627,6 +631,10 @@ def train_sam_for_configuration(
627631
628632def _export_helper (save_root , checkpoint_name , output_path , model_type , with_segmentation_decoder , val_loader ):
629633
634+ # Checkpoints are looked up under 'save_root'; fall back to the current working directory when it is not given.
635+ if save_root is None :
636+ save_root = os .getcwd () # Map this to current working directory, if not specified by the user.
637+
630638 # Get the 'best' model checkpoint ready for export.
631639 best_checkpoint = os .path .join (save_root , "checkpoints" , checkpoint_name , "best.pt" )
632640 if not os .path .exists (best_checkpoint ):
@@ -734,7 +742,8 @@ def main():
734742 )
735743 parser .add_argument (
736744 "--segmentation_decoder" , type = str , default = "instances" , # TODO: in future, we can extend this to semantic seg.
737- help = "Whether to finetune Segment Anything Model with additional instance segmentation decoder."
745+ help = "Whether to finetune Segment Anything Model with additional segmentation decoder for desired targets. "
746+ "By default, it trains with the additional segmentation decoder for instance segmentation."
738747 )
739748
740749 # Optional advanced settings a user can opt to change the values for.
@@ -779,6 +788,11 @@ def main():
779788 "--batch_size" , type = int , default = 1 ,
780789 help = "The choice of batch size for training the Segment Anything Model. By default, trains on batch size 1."
781790 )
791+ parser .add_argument (
792+ "--preprocess" , type = str , default = None , choices = ("normalize_minmax" , "normalize_percentile" ),
793+ help = "Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images "
794+ "Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'."
795+ )
782796
783797 args = parser .parse_args ()
784798
@@ -802,6 +816,9 @@ def main():
802816
803817 # 2. Prepare the dataloaders.
804818
819+ # If the user wants to preprocess the inputs, we allow the possibility to do so.
820+ _raw_transform = get_raw_transform (args .preprocess )
821+
805822 # Get the dataset with files for training.
806823 dataset = default_sam_dataset (
807824 raw_paths = train_images ,
@@ -810,13 +827,14 @@ def main():
810827 label_key = train_gt_key ,
811828 patch_shape = patch_shape ,
812829 with_segmentation_decoder = with_segmentation_decoder ,
830+ raw_transform = _raw_transform ,
813831 )
814832
815833 # If val images are not exclusively provided, we create a val split from the training data.
816834 if val_images is None :
817835 assert val_gt is None and val_image_key is None and val_gt_key is None
818836 # Use 10% of the dataset - but at least one image - for validation.
819- n_val = min (1 , int (0.1 * len (dataset )))
837+ n_val = max (1 , int (0.1 * len (dataset )))
820838 train_dataset , val_dataset = random_split (dataset , lengths = [len (dataset ) - n_val , n_val ])
821839
822840 else : # If val images provided, we create a new dataset for it.
@@ -828,6 +846,7 @@ def main():
828846 label_key = val_gt_key ,
829847 patch_shape = patch_shape ,
830848 with_segmentation_decoder = with_segmentation_decoder ,
849+ raw_transform = _raw_transform ,
831850 )
832851
833852 # Get the dataloaders from the datasets.
@@ -845,8 +864,6 @@ def main():
845864 if model_type is None : # If user does not specify the model, we use the default model corresponding to the config.
846865 model_type = CONFIGURATIONS [config ]["model_type" ]
847866
848- print (model_type , config )
849-
850867 train_sam_for_configuration (
851868 name = checkpoint_name ,
852869 configuration = config ,
0 commit comments