Skip to content

Commit 608c61e

Browse files
committed
fix: update train dataset classes with out_map_keys param
1 parent ceaf4bd commit 608c61e

File tree

2 files changed

+48
-29
lines changed

2 files changed

+48
-29
lines changed

cellseg_models_pytorch/torch_datasets/folder_dataset_train.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ def __init__(
2929
mask_keys: Tuple[str, ...],
3030
transforms: A.Compose,
3131
inst_transforms: ApplyEach,
32-
drop_keys: Tuple[str, ...] = None,
33-
output_device: str = "cuda",
32+
map_out_keys: Dict[str, str] = None,
3433
) -> None:
3534
"""Folder train dataset for cell/panoptic segmentation models.
3635
@@ -45,10 +44,12 @@ def __init__(
4544
Albumentations compose object for image and mask transforms.
4645
inst_transforms (ApplyEach):
4746
ApplyEach object for instance transforms.
48-
drop_keys (Tuple[str, ...], default=None):
49-
Tuple of keys to be dropped from the output dictionary.
50-
output_device (str):
51-
Output device for the image and masks.
47+
map_out_keys (Dict[str, str], default=None):
48+
A dictionary to map the default output keys to new output keys. .
49+
Useful if you want to match the output keys with model output keys.
50+
e.g. {"inst": "decoder1-inst", "inst-cellpose": decoder2-cellpose}.
51+
The default output keys are any of 'image', 'inst', 'type', 'cyto_inst',
52+
'cyto_type', 'sem' & inst-{transform.name}, cyto_inst-{transform.name}.
5253
5354
Raises:
5455
ModuleNotFoundError: If albumentations or tables is not installed.
@@ -84,8 +85,7 @@ def __init__(
8485
]
8586
self.transforms = transforms
8687
self.inst_transforms = inst_transforms
87-
self.output_device = output_device
88-
self.drop_keys = drop_keys
88+
self.map_out_keys = map_out_keys
8989

9090
def __len__(self) -> int:
9191
"""Return the number of items in the db."""
@@ -122,16 +122,24 @@ def __getitem__(self, ix: int) -> Dict[str, np.ndarray]:
122122
masks = to_tensor(tr["masks"][0])
123123
masks = torch.split(masks, mask_chls, dim=0)
124124

125-
integer_masks = {k: masks[i] for i, k in enumerate(self.mask_keys)}
125+
integer_masks = {
126+
# k: masks[i].squeeze().long() for i, k in enumerate(self.mask_keys)
127+
k: masks[i].squeeze()
128+
for i, k in enumerate(self.mask_keys)
129+
}
126130
inst_transformed_masks = {
127-
f"{n}": masks[len(integer_masks) + i]
131+
# f"{n}": masks[len(integer_masks) + i].float()
132+
n: masks[len(integer_masks) + i]
128133
for i, n in enumerate(self.inst_out_keys)
129134
}
130135

131-
out = {"image": image, **integer_masks, **inst_transformed_masks}
136+
out = {"image": image.float(), **integer_masks, **inst_transformed_masks}
132137

133-
if self.drop_keys is not None:
134-
for key in self.drop_keys:
135-
del out[key]
138+
if self.map_out_keys is not None:
139+
new_out = {}
140+
for in_key, out_key in self.map_out_keys.items():
141+
if in_key in out:
142+
new_out[out_key] = out.pop(in_key)
143+
out = new_out
136144

137145
return out

cellseg_models_pytorch/torch_datasets/hdf5_dataset_train.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ def __init__(
3333
input_keys: Tuple[str, ...],
3434
transforms: A.Compose,
3535
inst_transforms: ApplyEach,
36-
drop_keys: Tuple[str, ...] = None,
37-
output_device: str = "cuda",
36+
map_out_keys: Dict[str, str] = None,
3837
) -> None:
3938
"""HDF5 train dataset for cell/panoptic segmentation models.
4039
@@ -47,10 +46,12 @@ def __init__(
4746
Albumentations compose object for image and mask transforms.
4847
inst_transforms (ApplyEach):
4948
ApplyEach object for instance transforms.
50-
drop_keys (Tuple[str, ...], default=None):
51-
Tuple of keys to be dropped from the output dictionary.
52-
output_device (str):
53-
Output device for the image and masks.
49+
map_out_keys (Dict[str, str], default=None):
50+
A dictionary to map the default output keys to new output keys. .
51+
Useful if you want to match the output keys with model output keys.
52+
e.g. {"inst": "decoder1-inst", "inst-cellpose": decoder2-cellpose}.
53+
The default output keys are any of 'image', 'inst', 'type', 'cyto_inst',
54+
'cyto_type', 'sem' & inst-{transform.name}, cyto_inst-{transform.name}.
5455
5556
Raises:
5657
ModuleNotFoundError: If albumentations or tables is not installed.
@@ -87,14 +88,13 @@ def __init__(
8788
self.mask_keys = [k for k in input_keys if k != "image"]
8889
self.inst_in_keys = [k for k in input_keys if "inst" in k]
8990
self.inst_out_keys = [
90-
f"{name}_{key}"
91+
f"{key}-{name}"
9192
for name in inst_transforms.names
9293
for key in self.inst_in_keys
9394
]
9495
self.transforms = transforms
9596
self.inst_transforms = inst_transforms
96-
self.output_device = output_device
97-
self.drop_keys = drop_keys
97+
self.map_out_keys = map_out_keys
9898

9999
with tb.open_file(path, "r") as h5:
100100
self.n_items = len(h5.root["fname"][:])
@@ -129,16 +129,27 @@ def __getitem__(self, ix: int) -> Dict[str, np.ndarray]:
129129
masks = to_tensor(tr["masks"][0])
130130
masks = torch.split(masks, mask_chls, dim=0)
131131

132-
integer_masks = {k: masks[i] for i, k in enumerate(self.mask_keys)}
132+
integer_masks = {
133+
n: masks[i].squeeze().long()
134+
for i, n in enumerate(self.mask_keys)
135+
# n: masks[i].squeeze()
136+
# for i, n in enumerate(self.mask_keys)
137+
}
133138
inst_transformed_masks = {
134-
f"{n}": masks[len(integer_masks) + i]
139+
# n: masks[len(integer_masks) + i]
140+
# for i, n in enumerate(self.inst_out_keys)
141+
n: masks[len(integer_masks) + i].float()
135142
for i, n in enumerate(self.inst_out_keys)
136143
}
137144

138-
out = {"image": image, **integer_masks, **inst_transformed_masks}
145+
# out = {"image": image.float(), **integer_masks, **inst_transformed_masks}
146+
out = {"image": image.float(), **integer_masks, **inst_transformed_masks}
139147

140-
if self.drop_keys is not None:
141-
for key in self.drop_keys:
142-
del out[key]
148+
if self.map_out_keys is not None:
149+
new_out = {}
150+
for in_key, out_key in self.map_out_keys.items():
151+
if in_key in out:
152+
new_out[out_key] = out.pop(in_key)
153+
out = new_out
143154

144155
return out

0 commit comments

Comments
 (0)