Skip to content

Commit 1fb1e23

Browse files
author
Donglai Wei
committed
fix multiple folder saving issue
1 parent ae09e88 commit 1fb1e23

File tree

13 files changed

+1230
-563
lines changed

13 files changed

+1230
-563
lines changed

connectomics/config/hydra_config.py

Lines changed: 446 additions & 70 deletions
Large diffs are not rendered by default.

connectomics/data/process/monai_transforms.py

Lines changed: 63 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
class SegToBinaryMaskd(MapTransform):
2929
"""Convert segmentation to binary mask using MONAI MapTransform.
30-
30+
3131
Args:
3232
keys: Keys to transform
3333
segment_id: List of segment IDs to include as foreground.
@@ -59,7 +59,7 @@ class SegToAffinityMapd(MapTransform):
5959
def __init__(
6060
self,
6161
keys: KeysCollection,
62-
offsets: List[str] = ['1-1-0', '1-0-0', '0-1-0', '0-0-1'],
62+
offsets: List[str] = ["1-1-0", "1-0-0", "0-1-0", "0-0-1"],
6363
allow_missing_keys: bool = False,
6464
) -> None:
6565
super().__init__(keys, allow_missing_keys)
@@ -79,7 +79,7 @@ class SegToInstanceBoundaryMaskd(MapTransform):
7979
Args:
8080
keys: Keys to transform
8181
thickness: Thickness of the boundary (half-size of dilation struct) (default: 1)
82-
do_bg_edges: Generate contour between instances and background (default: True)
82+
edge_mode: Edge detection mode - "all", "seg-all", or "seg-no-bg" (default: "seg-all")
8383
mode: '2d' for slice-by-slice or '3d' for full 3D boundary detection (default: '3d')
8484
allow_missing_keys: Whether to allow missing keys
8585
"""
@@ -88,20 +88,20 @@ def __init__(
8888
self,
8989
keys: KeysCollection,
9090
thickness: int = 1,
91-
do_bg_edges: bool = True,
92-
mode: str = '3d',
91+
edge_mode: str = "seg-all",
92+
mode: str = "3d",
9393
allow_missing_keys: bool = False,
9494
) -> None:
9595
super().__init__(keys, allow_missing_keys)
9696
self.thickness = thickness
97-
self.do_bg_edges = do_bg_edges
97+
self.edge_mode = edge_mode
9898
self.mode = mode
9999

100100
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
101101
d = dict(data)
102102
for key in self.key_iterator(d):
103103
if key in d:
104-
d[key] = seg_to_instance_bd(d[key], self.thickness, self.do_bg_edges, self.mode)
104+
d[key] = seg_to_instance_bd(d[key], self.thickness, self.edge_mode, self.mode)
105105
return d
106106

107107

@@ -118,7 +118,7 @@ class SegToInstanceEDTd(MapTransform):
118118
def __init__(
119119
self,
120120
keys: KeysCollection,
121-
mode: str = '2d',
121+
mode: str = "2d",
122122
quantize: bool = False,
123123
allow_missing_keys: bool = False,
124124
) -> None:
@@ -223,7 +223,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
223223

224224
class SegToSemanticEDTd(MapTransform):
225225
"""Convert segmentation to semantic EDT using MONAI MapTransform.
226-
226+
227227
Args:
228228
keys: Keys to transform
229229
mode: EDT computation mode: '2d' or '3d' (default: '2d')
@@ -235,7 +235,7 @@ class SegToSemanticEDTd(MapTransform):
235235
def __init__(
236236
self,
237237
keys: KeysCollection,
238-
mode: str = '2d',
238+
mode: str = "2d",
239239
alpha_fore: float = 8.0,
240240
alpha_back: float = 50.0,
241241
allow_missing_keys: bool = False,
@@ -249,9 +249,9 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
249249
d = dict(data)
250250
for key in self.key_iterator(d):
251251
if key in d:
252-
d[key] = seg_to_semantic_edt(d[key], mode=self.mode,
253-
alpha_fore=self.alpha_fore,
254-
alpha_back=self.alpha_back)
252+
d[key] = seg_to_semantic_edt(
253+
d[key], mode=self.mode, alpha_fore=self.alpha_fore, alpha_back=self.alpha_back
254+
)
255255
return d
256256

257257

@@ -261,7 +261,7 @@ class SegToFlowFieldd(MapTransform):
261261
def __init__(
262262
self,
263263
keys: KeysCollection,
264-
target_opt: List[str] = ['1'],
264+
target_opt: List[str] = ["1"],
265265
allow_missing_keys: bool = False,
266266
) -> None:
267267
super().__init__(keys, allow_missing_keys)
@@ -278,7 +278,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
278278

279279
class SegToSynapticPolarityd(MapTransform):
280280
"""Convert segmentation to synaptic polarity using MONAI MapTransform.
281-
281+
282282
Args:
283283
keys: Keys to transform
284284
exclusive: If False, returns 3-channel non-exclusive masks (for BCE loss).
@@ -305,7 +305,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
305305

306306
class SegToSmallObjectd(MapTransform):
307307
"""Convert segmentation to small object mask using MONAI MapTransform.
308-
308+
309309
Args:
310310
keys: Keys to transform
311311
threshold: Maximum voxel count for objects to be considered small (default: 100)
@@ -335,7 +335,7 @@ class ComputeBinaryRatioWeightd(MapTransform):
335335
def __init__(
336336
self,
337337
keys: KeysCollection,
338-
target_opt: List[str] = ['1'],
338+
target_opt: List[str] = ["1"],
339339
allow_missing_keys: bool = False,
340340
) -> None:
341341
super().__init__(keys, allow_missing_keys)
@@ -355,7 +355,7 @@ class ComputeUNet3DWeightd(MapTransform):
355355
def __init__(
356356
self,
357357
keys: KeysCollection,
358-
target_opt: List[str] = ['1', '1', '5.0', '0.3'],
358+
target_opt: List[str] = ["1", "1", "5.0", "0.3"],
359359
allow_missing_keys: bool = False,
360360
) -> None:
361361
super().__init__(keys, allow_missing_keys)
@@ -461,7 +461,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
461461

462462
class EnergyQuantized(MapTransform):
463463
"""Quantize continuous energy maps using MONAI MapTransform.
464-
464+
465465
This transform converts continuous energy values to discrete quantized levels,
466466
useful for training neural networks on energy-based targets.
467467
"""
@@ -491,15 +491,15 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
491491

492492
class DecodeQuantized(MapTransform):
493493
"""Decode quantized energy maps back to continuous values using MONAI MapTransform.
494-
494+
495495
This transform converts quantized discrete levels back to continuous energy values,
496496
typically used for inference or evaluation.
497497
"""
498498

499499
def __init__(
500500
self,
501501
keys: KeysCollection,
502-
mode: str = 'max',
502+
mode: str = "max",
503503
allow_missing_keys: bool = False,
504504
) -> None:
505505
"""
@@ -509,7 +509,7 @@ def __init__(
509509
allow_missing_keys: Whether to ignore missing keys.
510510
"""
511511
super().__init__(keys, allow_missing_keys)
512-
if mode not in ['max', 'mean']:
512+
if mode not in ["max", "mean"]:
513513
raise ValueError(f"Mode must be 'max' or 'mean', got {mode}")
514514
self.mode = mode
515515

@@ -523,7 +523,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
523523

524524
class SegSelectiond(MapTransform):
525525
"""Select specific segmentation indices using MONAI MapTransform.
526-
526+
527527
This transform selects only the specified label indices from a segmentation,
528528
renumbering them consecutively starting from 1.
529529
"""
@@ -541,7 +541,9 @@ def __init__(
541541
allow_missing_keys: Whether to ignore missing keys.
542542
"""
543543
super().__init__(keys, allow_missing_keys)
544-
self.indices = ensure_tuple_rep(indices, 1) if not isinstance(indices, (list, tuple)) else indices
544+
self.indices = (
545+
ensure_tuple_rep(indices, 1) if not isinstance(indices, (list, tuple)) else indices
546+
)
545547

546548
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
547549
d = dict(data)
@@ -592,12 +594,18 @@ class MultiTaskLabelTransformd(MapTransform):
592594
}
593595
_TASK_DEFAULTS: Dict[str, Dict[str, Any]] = {
594596
"binary": {},
595-
"affinity": {"offsets": ['1-1-0', '1-0-0', '0-1-0', '0-0-1']},
596-
"instance_boundary": {"thickness": 1, "do_bg_edges": False, "mode": "3d"},
597+
"affinity": {"offsets": ["1-1-0", "1-0-0", "0-1-0", "0-0-1"]},
598+
"instance_boundary": {"thickness": 1, "edge_mode": "seg-all", "mode": "3d"},
597599
"instance_edt": {"mode": "2d", "quantize": False},
598-
"skeleton_aware_edt": {"bg_value": -1.0, "relabel": True, "padding": False,
599-
"resolution": (1.0, 1.0, 1.0), "alpha": 0.8,
600-
"smooth": True, "smooth_skeleton_only": True},
600+
"skeleton_aware_edt": {
601+
"bg_value": -1.0,
602+
"relabel": True,
603+
"padding": False,
604+
"resolution": (1.0, 1.0, 1.0),
605+
"alpha": 0.8,
606+
"smooth": True,
607+
"smooth_skeleton_only": True,
608+
},
601609
"semantic_edt": {"mode": "2d", "alpha_fore": 8.0, "alpha_back": 50.0},
602610
"polarity": {"exclusive": False},
603611
"small_object": {"threshold": 100},
@@ -681,7 +689,7 @@ def _init_tasks(
681689

682690
def _prepare_label(self, label: Any) -> Tuple[np.ndarray, bool]:
683691
"""Convert label to numpy without duplicating data where possible.
684-
692+
685693
MONAI transforms expect channel-first format [C, D, H, W], not batch-first [B, C, D, H, W].
686694
We should not remove the channel dimension.
687695
"""
@@ -709,7 +717,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
709717

710718
label = d[key]
711719
label_np, had_batch_dim = self._prepare_label(label)
712-
720+
713721
# Remove channel dimension if it's 1 (target functions expect [D, H, W] not [1, D, H, W])
714722
if label_np.ndim == 4 and label_np.shape[0] == 1:
715723
label_np = label_np[0]
@@ -719,7 +727,9 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
719727
result = spec["fn"](label_np, **spec["kwargs"])
720728
if result is None:
721729
raise RuntimeError(f"Task '{spec['name']}' returned None.")
722-
result_arr = np.asarray(result, dtype=np.float32) # Convert to float32 (handles bool->float)
730+
result_arr = np.asarray(
731+
result, dtype=np.float32
732+
) # Convert to float32 (handles bool->float)
723733

724734
# Ensure each output has a channel dimension [C, D, H, W]
725735
# If output is [D, H, W], expand to [1, D, H, W]
@@ -741,28 +751,30 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
741751
else:
742752
d[key] = label
743753
for spec, result in zip(self.task_specs, outputs):
744-
out_key = spec["output_key"] or self.output_key_format.format(key=key, task=spec["name"])
754+
out_key = spec["output_key"] or self.output_key_format.format(
755+
key=key, task=spec["name"]
756+
)
745757
d[out_key] = self._to_tensor(result, add_batch_dim=False)
746758
return d
747759

748760

749761
__all__ = [
750-
'SegToBinaryMaskd',
751-
'SegToAffinityMapd',
752-
'SegToInstanceBoundaryMaskd',
753-
'SegToInstanceEDTd',
754-
'SegToSkeletonAwareEDTd',
755-
'SegToSemanticEDTd',
756-
'SegToFlowFieldd',
757-
'SegToSynapticPolarityd',
758-
'SegToSmallObjectd',
759-
'ComputeBinaryRatioWeightd',
760-
'ComputeUNet3DWeightd',
761-
'SegErosiond',
762-
'SegDilationd',
763-
'SegErosionInstanced',
764-
'EnergyQuantized',
765-
'DecodeQuantized',
766-
'SegSelectiond',
767-
'MultiTaskLabelTransformd',
762+
"SegToBinaryMaskd",
763+
"SegToAffinityMapd",
764+
"SegToInstanceBoundaryMaskd",
765+
"SegToInstanceEDTd",
766+
"SegToSkeletonAwareEDTd",
767+
"SegToSemanticEDTd",
768+
"SegToFlowFieldd",
769+
"SegToSynapticPolarityd",
770+
"SegToSmallObjectd",
771+
"ComputeBinaryRatioWeightd",
772+
"ComputeUNet3DWeightd",
773+
"SegErosiond",
774+
"SegDilationd",
775+
"SegErosionInstanced",
776+
"EnergyQuantized",
777+
"DecodeQuantized",
778+
"SegSelectiond",
779+
"MultiTaskLabelTransformd",
768780
]

0 commit comments

Comments
 (0)