2727
2828class SegToBinaryMaskd (MapTransform ):
2929 """Convert segmentation to binary mask using MONAI MapTransform.
30-
30+
3131 Args:
3232 keys: Keys to transform
3333 segment_id: List of segment IDs to include as foreground.
@@ -59,7 +59,7 @@ class SegToAffinityMapd(MapTransform):
5959 def __init__ (
6060 self ,
6161 keys : KeysCollection ,
62- offsets : List [str ] = [' 1-1-0' , ' 1-0-0' , ' 0-1-0' , ' 0-0-1' ],
62+ offsets : List [str ] = [" 1-1-0" , " 1-0-0" , " 0-1-0" , " 0-0-1" ],
6363 allow_missing_keys : bool = False ,
6464 ) -> None :
6565 super ().__init__ (keys , allow_missing_keys )
@@ -79,7 +79,7 @@ class SegToInstanceBoundaryMaskd(MapTransform):
7979 Args:
8080 keys: Keys to transform
8181 thickness: Thickness of the boundary (half-size of dilation struct) (default: 1)
82- do_bg_edges: Generate contour between instances and background (default: True )
82+ edge_mode: Edge detection mode - "all", "seg-all", or "seg-no-bg" (default: "seg-all" )
8383 mode: '2d' for slice-by-slice or '3d' for full 3D boundary detection (default: '3d')
8484 allow_missing_keys: Whether to allow missing keys
8585 """
@@ -88,20 +88,20 @@ def __init__(
8888 self ,
8989 keys : KeysCollection ,
9090 thickness : int = 1 ,
91- do_bg_edges : bool = True ,
92- mode : str = '3d' ,
91+ edge_mode : str = "seg-all" ,
92+ mode : str = "3d" ,
9393 allow_missing_keys : bool = False ,
9494 ) -> None :
9595 super ().__init__ (keys , allow_missing_keys )
9696 self .thickness = thickness
97- self .do_bg_edges = do_bg_edges
97+ self .edge_mode = edge_mode
9898 self .mode = mode
9999
100100 def __call__ (self , data : Dict [str , Any ]) -> Dict [str , Any ]:
101101 d = dict (data )
102102 for key in self .key_iterator (d ):
103103 if key in d :
104- d [key ] = seg_to_instance_bd (d [key ], self .thickness , self .do_bg_edges , self .mode )
104+ d [key ] = seg_to_instance_bd (d [key ], self .thickness , self .edge_mode , self .mode )
105105 return d
106106
107107
@@ -118,7 +118,7 @@ class SegToInstanceEDTd(MapTransform):
118118 def __init__ (
119119 self ,
120120 keys : KeysCollection ,
121- mode : str = '2d' ,
121+ mode : str = "2d" ,
122122 quantize : bool = False ,
123123 allow_missing_keys : bool = False ,
124124 ) -> None :
@@ -223,7 +223,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
223223
224224class SegToSemanticEDTd (MapTransform ):
225225 """Convert segmentation to semantic EDT using MONAI MapTransform.
226-
226+
227227 Args:
228228 keys: Keys to transform
229229 mode: EDT computation mode: '2d' or '3d' (default: '2d')
@@ -235,7 +235,7 @@ class SegToSemanticEDTd(MapTransform):
235235 def __init__ (
236236 self ,
237237 keys : KeysCollection ,
238- mode : str = '2d' ,
238+ mode : str = "2d" ,
239239 alpha_fore : float = 8.0 ,
240240 alpha_back : float = 50.0 ,
241241 allow_missing_keys : bool = False ,
@@ -249,9 +249,9 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
249249 d = dict (data )
250250 for key in self .key_iterator (d ):
251251 if key in d :
252- d [key ] = seg_to_semantic_edt (d [ key ], mode = self . mode ,
253- alpha_fore = self .alpha_fore ,
254- alpha_back = self . alpha_back )
252+ d [key ] = seg_to_semantic_edt (
253+ d [ key ], mode = self . mode , alpha_fore = self .alpha_fore , alpha_back = self . alpha_back
254+ )
255255 return d
256256
257257
@@ -261,7 +261,7 @@ class SegToFlowFieldd(MapTransform):
261261 def __init__ (
262262 self ,
263263 keys : KeysCollection ,
264- target_opt : List [str ] = ['1' ],
264+ target_opt : List [str ] = ["1" ],
265265 allow_missing_keys : bool = False ,
266266 ) -> None :
267267 super ().__init__ (keys , allow_missing_keys )
@@ -278,7 +278,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
278278
279279class SegToSynapticPolarityd (MapTransform ):
280280 """Convert segmentation to synaptic polarity using MONAI MapTransform.
281-
281+
282282 Args:
283283 keys: Keys to transform
284284 exclusive: If False, returns 3-channel non-exclusive masks (for BCE loss).
@@ -305,7 +305,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
305305
306306class SegToSmallObjectd (MapTransform ):
307307 """Convert segmentation to small object mask using MONAI MapTransform.
308-
308+
309309 Args:
310310 keys: Keys to transform
311311 threshold: Maximum voxel count for objects to be considered small (default: 100)
@@ -335,7 +335,7 @@ class ComputeBinaryRatioWeightd(MapTransform):
335335 def __init__ (
336336 self ,
337337 keys : KeysCollection ,
338- target_opt : List [str ] = ['1' ],
338+ target_opt : List [str ] = ["1" ],
339339 allow_missing_keys : bool = False ,
340340 ) -> None :
341341 super ().__init__ (keys , allow_missing_keys )
@@ -355,7 +355,7 @@ class ComputeUNet3DWeightd(MapTransform):
355355 def __init__ (
356356 self ,
357357 keys : KeysCollection ,
358- target_opt : List [str ] = ['1' , '1' , ' 5.0' , ' 0.3' ],
358+ target_opt : List [str ] = ["1" , "1" , " 5.0" , " 0.3" ],
359359 allow_missing_keys : bool = False ,
360360 ) -> None :
361361 super ().__init__ (keys , allow_missing_keys )
@@ -461,7 +461,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
461461
462462class EnergyQuantized (MapTransform ):
463463 """Quantize continuous energy maps using MONAI MapTransform.
464-
464+
465465 This transform converts continuous energy values to discrete quantized levels,
466466 useful for training neural networks on energy-based targets.
467467 """
@@ -491,15 +491,15 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
491491
492492class DecodeQuantized (MapTransform ):
493493 """Decode quantized energy maps back to continuous values using MONAI MapTransform.
494-
494+
495495 This transform converts quantized discrete levels back to continuous energy values,
496496 typically used for inference or evaluation.
497497 """
498498
499499 def __init__ (
500500 self ,
501501 keys : KeysCollection ,
502- mode : str = ' max' ,
502+ mode : str = " max" ,
503503 allow_missing_keys : bool = False ,
504504 ) -> None :
505505 """
@@ -509,7 +509,7 @@ def __init__(
509509 allow_missing_keys: Whether to ignore missing keys.
510510 """
511511 super ().__init__ (keys , allow_missing_keys )
512- if mode not in [' max' , ' mean' ]:
512+ if mode not in [" max" , " mean" ]:
513513 raise ValueError (f"Mode must be 'max' or 'mean', got { mode } " )
514514 self .mode = mode
515515
@@ -523,7 +523,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
523523
524524class SegSelectiond (MapTransform ):
525525 """Select specific segmentation indices using MONAI MapTransform.
526-
526+
527527 This transform selects only the specified label indices from a segmentation,
528528 renumbering them consecutively starting from 1.
529529 """
@@ -541,7 +541,9 @@ def __init__(
541541 allow_missing_keys: Whether to ignore missing keys.
542542 """
543543 super ().__init__ (keys , allow_missing_keys )
544- self .indices = ensure_tuple_rep (indices , 1 ) if not isinstance (indices , (list , tuple )) else indices
544+ self .indices = (
545+ ensure_tuple_rep (indices , 1 ) if not isinstance (indices , (list , tuple )) else indices
546+ )
545547
546548 def __call__ (self , data : Dict [str , Any ]) -> Dict [str , Any ]:
547549 d = dict (data )
@@ -592,12 +594,18 @@ class MultiTaskLabelTransformd(MapTransform):
592594 }
593595 _TASK_DEFAULTS : Dict [str , Dict [str , Any ]] = {
594596 "binary" : {},
595- "affinity" : {"offsets" : [' 1-1-0' , ' 1-0-0' , ' 0-1-0' , ' 0-0-1' ]},
596- "instance_boundary" : {"thickness" : 1 , "do_bg_edges " : False , "mode" : "3d" },
597+ "affinity" : {"offsets" : [" 1-1-0" , " 1-0-0" , " 0-1-0" , " 0-0-1" ]},
598+ "instance_boundary" : {"thickness" : 1 , "edge_mode " : "seg-all" , "mode" : "3d" },
597599 "instance_edt" : {"mode" : "2d" , "quantize" : False },
598- "skeleton_aware_edt" : {"bg_value" : - 1.0 , "relabel" : True , "padding" : False ,
599- "resolution" : (1.0 , 1.0 , 1.0 ), "alpha" : 0.8 ,
600- "smooth" : True , "smooth_skeleton_only" : True },
600+ "skeleton_aware_edt" : {
601+ "bg_value" : - 1.0 ,
602+ "relabel" : True ,
603+ "padding" : False ,
604+ "resolution" : (1.0 , 1.0 , 1.0 ),
605+ "alpha" : 0.8 ,
606+ "smooth" : True ,
607+ "smooth_skeleton_only" : True ,
608+ },
601609 "semantic_edt" : {"mode" : "2d" , "alpha_fore" : 8.0 , "alpha_back" : 50.0 },
602610 "polarity" : {"exclusive" : False },
603611 "small_object" : {"threshold" : 100 },
@@ -681,7 +689,7 @@ def _init_tasks(
681689
682690 def _prepare_label (self , label : Any ) -> Tuple [np .ndarray , bool ]:
683691 """Convert label to numpy without duplicating data where possible.
684-
692+
685693 MONAI transforms expect channel-first format [C, D, H, W], not batch-first [B, C, D, H, W].
686694 We should not remove the channel dimension.
687695 """
@@ -709,7 +717,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
709717
710718 label = d [key ]
711719 label_np , had_batch_dim = self ._prepare_label (label )
712-
720+
713721 # Remove channel dimension if it's 1 (target functions expect [D, H, W] not [1, D, H, W])
714722 if label_np .ndim == 4 and label_np .shape [0 ] == 1 :
715723 label_np = label_np [0 ]
@@ -719,7 +727,9 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
719727 result = spec ["fn" ](label_np , ** spec ["kwargs" ])
720728 if result is None :
721729 raise RuntimeError (f"Task '{ spec ['name' ]} ' returned None." )
722- result_arr = np .asarray (result , dtype = np .float32 ) # Convert to float32 (handles bool->float)
730+ result_arr = np .asarray (
731+ result , dtype = np .float32
732+ ) # Convert to float32 (handles bool->float)
723733
724734 # Ensure each output has a channel dimension [C, D, H, W]
725735 # If output is [D, H, W], expand to [1, D, H, W]
@@ -741,28 +751,30 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
741751 else :
742752 d [key ] = label
743753 for spec , result in zip (self .task_specs , outputs ):
744- out_key = spec ["output_key" ] or self .output_key_format .format (key = key , task = spec ["name" ])
754+ out_key = spec ["output_key" ] or self .output_key_format .format (
755+ key = key , task = spec ["name" ]
756+ )
745757 d [out_key ] = self ._to_tensor (result , add_batch_dim = False )
746758 return d
747759
748760
749761__all__ = [
750- ' SegToBinaryMaskd' ,
751- ' SegToAffinityMapd' ,
752- ' SegToInstanceBoundaryMaskd' ,
753- ' SegToInstanceEDTd' ,
754- ' SegToSkeletonAwareEDTd' ,
755- ' SegToSemanticEDTd' ,
756- ' SegToFlowFieldd' ,
757- ' SegToSynapticPolarityd' ,
758- ' SegToSmallObjectd' ,
759- ' ComputeBinaryRatioWeightd' ,
760- ' ComputeUNet3DWeightd' ,
761- ' SegErosiond' ,
762- ' SegDilationd' ,
763- ' SegErosionInstanced' ,
764- ' EnergyQuantized' ,
765- ' DecodeQuantized' ,
766- ' SegSelectiond' ,
767- ' MultiTaskLabelTransformd' ,
762+ " SegToBinaryMaskd" ,
763+ " SegToAffinityMapd" ,
764+ " SegToInstanceBoundaryMaskd" ,
765+ " SegToInstanceEDTd" ,
766+ " SegToSkeletonAwareEDTd" ,
767+ " SegToSemanticEDTd" ,
768+ " SegToFlowFieldd" ,
769+ " SegToSynapticPolarityd" ,
770+ " SegToSmallObjectd" ,
771+ " ComputeBinaryRatioWeightd" ,
772+ " ComputeUNet3DWeightd" ,
773+ " SegErosiond" ,
774+ " SegDilationd" ,
775+ " SegErosionInstanced" ,
776+ " EnergyQuantized" ,
777+ " DecodeQuantized" ,
778+ " SegSelectiond" ,
779+ " MultiTaskLabelTransformd" ,
768780]
0 commit comments