Skip to content

Commit ea2f581

Browse files
Implement support for anisotropic training
1 parent c0d2820 commit ea2f581

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

flamingo_tools/training/util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def get_supervised_loader(
2929
label_key: Optional[str] = None,
3030
n_samples: Optional[int] = None,
3131
raw_transform: Optional[callable] = None,
32+
anisotropy: Optional[float] = None,
3233
) -> DataLoader:
3334
"""Get a data loader for a supervised segmentation task.
3435
@@ -41,14 +42,16 @@ def get_supervised_loader(
4142
image_key: Internal path for the label masks. This is only required for hdf5/zarr/n5 data.
4243
n_samples: The number of samples to use for training.
4344
raw_transform: Optional transformation for the raw data.
45+
anisotropy: The anisotropy factor for distance target computation.
4446
4547
Returns:
4648
The data loader.
4749
"""
4850
assert len(image_paths) == len(label_paths)
4951
assert len(image_paths) > 0
52+
sampling = None if anisotropy is None else (anisotropy, 1.0, 1.0)
5053
label_transform = torch_em.transform.label.PerObjectDistanceTransform(
51-
distances=True, boundary_distances=True, foreground=True,
54+
distances=True, boundary_distances=True, foreground=True, sampling=sampling,
5255
)
5356
sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8)
5457
loader = torch_em.default_segmentation_loader(

scripts/training/train_distance_unet.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def select_paths(image_paths, label_paths, split, filter_empty, random_split=Tru
8080
return image_paths, label_paths
8181

8282

83-
def get_loader(root, split, patch_shape, batch_size, filter_empty, separate_folders):
83+
def get_loader(root, split, patch_shape, batch_size, filter_empty, separate_folders, anisotropy):
8484
if separate_folders:
8585
image_paths, label_paths = get_image_and_label_paths_sep_folders(root)
8686
else:
@@ -96,7 +96,9 @@ def get_loader(root, split, patch_shape, batch_size, filter_empty, separate_fold
9696
n_samples = 16 * batch_size
9797

9898
return (
99-
get_supervised_loader(this_image_paths, this_label_paths, patch_shape, batch_size, n_samples=n_samples),
99+
get_supervised_loader(
100+
this_image_paths, this_label_paths, patch_shape, batch_size, n_samples=n_samples, anisotropy=anisotropy
101+
),
100102
this_image_paths,
101103
this_label_paths
102104
)
@@ -124,6 +126,10 @@ def main():
124126
parser.add_argument(
125127
"--name", help="Optional name for the model to be trained. If not given the current date is used."
126128
)
129+
parser.add_argument(
130+
"--anisotropy", help="Anisotropy factor of the Z-Axis (Depth). Will be used to scale distance targets.",
131+
type=float,
132+
)
127133
parser.add_argument("--separate_folders", action="store_true")
128134
args = parser.parse_args()
129135
root = args.root
@@ -141,10 +147,12 @@ def main():
141147

142148
# Create the training loader with train and val set.
143149
train_loader, train_images, train_labels = get_loader(
144-
root, "train", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders
150+
root, "train", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders,
151+
anisotropy=args.anisotropy,
145152
)
146153
val_loader, val_images, val_labels = get_loader(
147-
root, "val", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders
154+
root, "val", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders,
155+
anisotropy=args.anisotropy,
148156
)
149157

150158
if check_loaders:

0 commit comments

Comments
 (0)