Skip to content

Commit 6f22309

Browse files
committed
Fix instance estimations from labels
1 parent 2e755cc commit 6f22309

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

finetuning/livecell_finetuning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ def get_dataloaders(patch_shape, data_path, cell_type=None):
2323
label_transform = torch_em.transform.label.label_consecutive # to ensure consecutive IDs
2424
train_loader = get_livecell_loader(path=data_path, patch_shape=patch_shape, split="train", batch_size=2,
2525
num_workers=8, cell_types=cell_type, download=True,
26-
label_transform=label_transform)
26+
label_transform=label_transform, shuffle=True)
2727
val_loader = get_livecell_loader(path=data_path, patch_shape=patch_shape, split="val", batch_size=1,
2828
num_workers=8, cell_types=cell_type, download=True,
29-
label_transform=label_transform)
29+
label_transform=label_transform, shuffle=True)
3030
return train_loader, val_loader
3131

3232

micro_sam/training/sam_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc
258258

259259
def _update_samples_for_gt_instances(self, y, n_samples):
260260
num_instances_gt = torch.amax(y, dim=(1, 2, 3))
261+
num_instances_gt = num_instances_gt.numpy()
261262
n_samples = min(num_instances_gt) if n_samples > min(num_instances_gt) else n_samples
262263
return n_samples
263264

0 commit comments

Comments
 (0)