Skip to content

Commit da6c847

Browse files
Finish micro-sam training iteration
1 parent 2fd0344 commit da6c847

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

scripts/training/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ They will load the **image data** according to the following rules:
1414
The training script will save the trained model in `checkpoints/cochlea_distance_unet_<CURRENT_DATE>`, e.g. `checkpoints/cochlea_distance_unet_20250115`.
1515
For further options for the scripts run `python check_training_data.py -h` / `python train_distance_unet.py -h`.
1616

17-
The script `train_micro_sam.py` works similar to the U-Net training script. It saves the finetuned model for annotation with `micro_sam` to `checkpoints/`.
17+
The script `train_micro_sam.py` works similar to the U-Net training script. It saves the finetuned model for annotation with `micro_sam` to `checkpoints/cochlea_micro_sam_<CURRENT_DATE>`.

scripts/training/train_micro_sam.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from datetime import datetime
33

44
import numpy as np
5-
import torch_em
65
from micro_sam.training import default_sam_loader, train_sam
76
from train_distance_unet import get_image_and_label_paths, select_paths
87

@@ -36,35 +35,35 @@ def main():
3635

3736
root = args.root
3837
run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name
39-
name = f"cochlea_distance_unet_{run_name}"
38+
name = f"cochlea_micro_sam_{run_name}"
4039
n_objects_per_batch = args.n_objects_per_batch
4140

4241
image_paths, label_paths = get_image_and_label_paths(root)
4342
train_image_paths, train_label_paths = select_paths(image_paths, label_paths, split="train", filter_empty=True)
4443
val_image_paths, val_label_paths = select_paths(image_paths, label_paths, split="val", filter_empty=True)
4544

4645
patch_shape = (1, 256, 256)
47-
sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=10)
4846
max_sampling_attempts = 2500
4947

5048
train_loader = default_sam_loader(
5149
raw_paths=train_image_paths, raw_key=None, label_paths=train_label_paths, label_key=None,
5250
patch_shape=patch_shape, with_segmentation_decoder=True,
53-
raw_transform=raw_transform, sampler=sampler, min_size=10,
51+
raw_transform=raw_transform,
5452
num_workers=6, batch_size=1, is_train=True,
5553
max_sampling_attempts=max_sampling_attempts,
5654
)
5755
val_loader = default_sam_loader(
5856
raw_paths=val_image_paths, raw_key=None, label_paths=val_label_paths, label_key=None,
5957
patch_shape=patch_shape, with_segmentation_decoder=True,
60-
raw_transform=raw_transform, sampler=sampler, min_size=10,
58+
raw_transform=raw_transform,
6159
num_workers=6, batch_size=1, is_train=False,
6260
max_sampling_attempts=max_sampling_attempts,
6361
)
6462

6563
train_sam(
6664
name=name, model_type="vit_b_lm", train_loader=train_loader, val_loader=val_loader,
6765
n_epochs=50, n_objects_per_batch=n_objects_per_batch,
66+
save_root=".",
6867
)
6968

7069

0 commit comments

Comments
 (0)