@@ -33,8 +33,7 @@ def __init__(
3333 input_keys : Tuple [str , ...],
3434 transforms : A .Compose ,
3535 inst_transforms : ApplyEach ,
36- drop_keys : Tuple [str , ...] = None ,
37- output_device : str = "cuda" ,
36+ map_out_keys : Dict [str , str ] = None ,
3837 ) -> None :
3938 """HDF5 train dataset for cell/panoptic segmentation models.
4039
@@ -47,10 +46,12 @@ def __init__(
4746 Albumentations compose object for image and mask transforms.
4847 inst_transforms (ApplyEach):
4948 ApplyEach object for instance transforms.
50- drop_keys (Tuple[str, ...], default=None):
51- Tuple of keys to be dropped from the output dictionary.
52- output_device (str):
53- Output device for the image and masks.
49+ map_out_keys (Dict[str, str], default=None):
50+ A dictionary to map the default output keys to new output keys. .
51+ Useful if you want to match the output keys with model output keys.
52+ e.g. {"inst": "decoder1-inst", "inst-cellpose": decoder2-cellpose}.
53+ The default output keys are any of 'image', 'inst', 'type', 'cyto_inst',
54+ 'cyto_type', 'sem' & inst-{transform.name}, cyto_inst-{transform.name}.
5455
5556 Raises:
5657 ModuleNotFoundError: If albumentations or tables is not installed.
@@ -87,14 +88,13 @@ def __init__(
8788 self .mask_keys = [k for k in input_keys if k != "image" ]
8889 self .inst_in_keys = [k for k in input_keys if "inst" in k ]
8990 self .inst_out_keys = [
90- f"{ name } _ { key } "
91+ f"{ key } - { name } "
9192 for name in inst_transforms .names
9293 for key in self .inst_in_keys
9394 ]
9495 self .transforms = transforms
9596 self .inst_transforms = inst_transforms
96- self .output_device = output_device
97- self .drop_keys = drop_keys
97+ self .map_out_keys = map_out_keys
9898
9999 with tb .open_file (path , "r" ) as h5 :
100100 self .n_items = len (h5 .root ["fname" ][:])
@@ -129,16 +129,27 @@ def __getitem__(self, ix: int) -> Dict[str, np.ndarray]:
129129 masks = to_tensor (tr ["masks" ][0 ])
130130 masks = torch .split (masks , mask_chls , dim = 0 )
131131
132- integer_masks = {k : masks [i ] for i , k in enumerate (self .mask_keys )}
132+ integer_masks = {
133+ n : masks [i ].squeeze ().long ()
134+ for i , n in enumerate (self .mask_keys )
135+ # n: masks[i].squeeze()
136+ # for i, n in enumerate(self.mask_keys)
137+ }
133138 inst_transformed_masks = {
134- f"{ n } " : masks [len (integer_masks ) + i ]
139+ # n: masks[len(integer_masks) + i]
140+ # for i, n in enumerate(self.inst_out_keys)
141+ n : masks [len (integer_masks ) + i ].float ()
135142 for i , n in enumerate (self .inst_out_keys )
136143 }
137144
138- out = {"image" : image , ** integer_masks , ** inst_transformed_masks }
145+ # out = {"image": image.float(), **integer_masks, **inst_transformed_masks}
146+ out = {"image" : image .float (), ** integer_masks , ** inst_transformed_masks }
139147
140- if self .drop_keys is not None :
141- for key in self .drop_keys :
142- del out [key ]
148+ if self .map_out_keys is not None :
149+ new_out = {}
150+ for in_key , out_key in self .map_out_keys .items ():
151+ if in_key in out :
152+ new_out [out_key ] = out .pop (in_key )
153+ out = new_out
143154
144155 return out
0 commit comments