Skip to content

Commit ded4607

Browse files
committed
Update BiaPy-core functions
1 parent 442aa81 commit ded4607

File tree

2 files changed

+62
-15
lines changed

2 files changed

+62
-15
lines changed

biapy/biapy_check_configuration.py

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
## Copied from BiaPy commit: 3db4edcf634b3726484d650bc03058c5f36d3a7c (3.6.6)
1+
## Copied from BiaPy commit: b321fbf3700bec39480dde84fc2d37e4081c9581 (3.6.7)
2+
"""
3+
Configuration checking utilities for BiaPy.
4+
5+
This module provides functions to validate, compare, and update BiaPy configuration
6+
objects, ensuring that all required settings are present and consistent for a given
7+
workflow. It includes compatibility checks for data, model, augmentation, and
8+
post-processing options.
9+
"""
210
import os
311
import re
412
from typing import List, Tuple, Any, Dict
@@ -156,6 +164,24 @@ def sort_key(item):
156164
if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.GROWTH_MASK_CHANNELS == []:
157165
growth_mask_channels = ["F"]
158166
growth_mask_channel_ths = ["auto"]
167+
elif set(sorted_original_instance_channels) == {"F", "Dc"}:
168+
if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.SEED_CHANNELS == []:
169+
seed_channels = ["F", "Dc"]
170+
seed_channels_thresh = ["auto", "auto"]
171+
if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.TOPOGRAPHIC_SURFACE_CHANNEL == "":
172+
topo_surface_ch = "F"
173+
if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.GROWTH_MASK_CHANNELS == []:
174+
growth_mask_channels = ["F"]
175+
growth_mask_channel_ths = ["auto"]
176+
elif set(sorted_original_instance_channels) == {"F", "Dn"}:
177+
if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.SEED_CHANNELS == []:
178+
seed_channels = ["F", "Dn"]
179+
seed_channels_thresh = ["auto", "auto"]
180+
if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.TOPOGRAPHIC_SURFACE_CHANNEL == "":
181+
topo_surface_ch = "F"
182+
if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.GROWTH_MASK_CHANNELS == []:
183+
growth_mask_channels = ["F"]
184+
growth_mask_channel_ths = ["auto"]
159185
elif set(sorted_original_instance_channels) == {"F", "P"}:
160186
if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.SEED_CHANNELS == []:
161187
seed_channels = ["F", "P"]
@@ -386,7 +412,7 @@ def sort_key(item):
386412
# Dc — center/skeleton distance-to-center
387413
if "Dc" in chs:
388414
dst["Dc"] = {
389-
"type": dst.get("Dc", {}).get("mode", "thick"),
415+
"type": dst.get("Dc", {}).get("mode", "centroid"),
390416
"norm": dst.get("Dc", {}).get("norm", True),
391417
"mask_values": dst.get("Dc", {}).get("mask_values", True),
392418
}
@@ -1031,8 +1057,14 @@ def sort_key(item):
10311057
"W_CE_DICE",
10321058
], "LOSS.TYPE not in ['CE', 'DICE', 'W_CE_DICE']"
10331059

1034-
if cfg.DATA.N_CLASSES > 2 and loss != "CE":
1035-
raise ValueError("'DATA.N_CLASSES' can only be done with 'CE' loss")
1060+
if cfg.DATA.N_CLASSES > 2:
1061+
if loss != "CE":
1062+
raise ValueError("'DATA.N_CLASSES' are only used with 'CE' loss and not with {}".format(loss))
1063+
if cfg.LOSS.CLASS_REBALANCE == "auto":
1064+
raise ValueError(
1065+
"'LOSS.CLASS_REBALANCE' can not be set to 'auto' when 'DATA.N_CLASSES' > 2 as it is only valid for binary problems. " \
1066+
"Use 'manual' and 'LOSS.CLASS_WEIGHTS' if you really want to rebalance classes. If not, set 'LOSS.CLASS_REBALANCE' to 'none'."
1067+
)
10361068
if loss == "W_CE_DICE":
10371069
assert (
10381070
len(cfg.LOSS.WEIGHTS) == 2
@@ -1314,8 +1346,8 @@ def sort_key(item):
13141346
assert isinstance(val["norm"], bool)
13151347
_assert_bool(val, "mask_values", ctx)
13161348

1317-
elif key == "Dc": # distance-to-center
1318-
_assert_str_in(val, "type", {"center", "skeleton"}, ctx)
1349+
elif key == "Dc": # distance-to-centroid
1350+
_assert_str_in(val, "type", {"centroid", "skeleton"}, ctx)
13191351
_assert_optional_bool(val, "norm", ctx)
13201352
_assert_bool(val, "mask_values", ctx)
13211353

@@ -2259,14 +2291,27 @@ def sort_key(item):
22592291

22602292
# Adjust Z_DOWN values to feature maps
22612293
if all(x == 0 for x in cfg.MODEL.Z_DOWN):
2262-
opts.extend(["MODEL.Z_DOWN", (2,) * (len(cfg.MODEL.FEATURE_MAPS) - 1)])
2294+
if model_arch == "multiresunet":
2295+
opts.extend(["MODEL.Z_DOWN", (2, 2, 2, 2)])
2296+
else:
2297+
opts.extend(["MODEL.Z_DOWN", (2,) * (len(cfg.MODEL.FEATURE_MAPS) - 1)])
22632298
elif any([False for x in cfg.MODEL.Z_DOWN if x != 1 and x != 2]):
22642299
raise ValueError("'MODEL.Z_DOWN' needs to be 1 or 2")
22652300
else:
22662301
if model_arch == "multiresunet" and len(cfg.MODEL.Z_DOWN) != 4:
22672302
raise ValueError("'MODEL.Z_DOWN' length must be 4 when using 'multiresunet'")
2268-
elif len(cfg.MODEL.FEATURE_MAPS) - 1 != len(cfg.MODEL.Z_DOWN):
2269-
raise ValueError("'MODEL.FEATURE_MAPS' length minus one and 'MODEL.Z_DOWN' length must be equal")
2303+
elif model_arch in [
2304+
"unet",
2305+
"resunet",
2306+
"resunet++",
2307+
"seunet",
2308+
"resunet_se",
2309+
"attention_unet",
2310+
"unext_v1",
2311+
"unext_v2",
2312+
]:
2313+
if len(cfg.MODEL.FEATURE_MAPS) - 1 != len(cfg.MODEL.Z_DOWN):
2314+
raise ValueError("'MODEL.FEATURE_MAPS' length minus one and 'MODEL.Z_DOWN' length must be equal")
22702315

22712316
# Adjust ISOTROPY values to feature maps
22722317
if all(x == True for x in cfg.MODEL.ISOTROPY):
@@ -3275,7 +3320,6 @@ def convert_old_model_cfg_to_current_version(old_cfg: dict):
32753320

32763321
return old_cfg
32773322

3278-
32793323
# Function extracted from check_configuration checks
32803324
def check_torchvision_available_models(workflow: str, ndim: str) -> Tuple[List[str], List[str], List[Dict[str, Any]]]:
32813325
"""

biapy/biapy_config.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
## Copied from BiaPy commit: b1c40e7bf89cb12a6413ef8dc55f49f1d1a45c34 (3.6.6)
1+
## Copied from BiaPy commit: b321fbf3700bec39480dde84fc2d37e4081c9581 (3.6.7)
22
"""
33
Configuration management for BiaPy.
44
@@ -154,7 +154,7 @@ def __init__(self, job_dir: str, job_identifier: str):
154154
# - 'norm': bool, whether to normalize the distances between 0 and 1. Default: False
155155
# - 'mask_values': bool, whether to mask the distance channel to only calculate the loss in non-zero values. Default: True
156156
# - 'Dc' channel. Possible options:
157-
# - 'type': str, the type of the channel. Options are: 'center', 'skeleton'. Default: 'center'
157+
# - 'type': str, the type of the channel. Options are: 'centroid', 'skeleton'. Default: 'centroid'
158158
# - 'norm': bool, whether to normalize the distances between 0 and 1. Default: False
159159
# - 'mask_values': bool, whether to mask the distance channel to only calculate the loss in non-zero values. Default: True
160160
# - 'Dn' channel. Possible options:
@@ -1403,11 +1403,14 @@ def __init__(self, job_dir: str, job_identifier: str):
14031403
# It works for all the workflows but the instance segmentation one, as in that case the weights must be set
14041404
# in PROBLEM.INSTANCE_SEG.DATA_CHANNEL_WEIGHTS. The weights must sum 1. E.g. [0.3, 0.7].
14051405
_C.LOSS.WEIGHTS = [0.66, 0.34]
1406-
# To weight classes in an imbalanced dataset. It can be 'none', 'manual' or 'auto'.
1407-
# Options:
1406+
# To weight classes in an imbalanced dataset. Options available are:
14081407
# * 'none': no class rebalancing is applied
14091408
# * 'manual': the weights provided in LOSS.CLASS_WEIGHTS are used to weight each class
1410-
# * 'auto': the weights are calculated automatically based on the number of pixels of each class per batch and directly in the loss computation.
1409+
# * 'auto': the weights are calculated automatically based on the number of pixels of each class per batch and directly in the loss computation.
1410+
# This option is only applied for binary clases. That is to say:
1411+
# * When LOSS.TYPE == "CE" in semantic segmentation and detection workflows and MODEL.N_CLASSES == 2.
1412+
# * In instance segmentation when PROBLEM.INSTANCE_SEG.DATA_CHANNELS_LOSSES contains "CE". This is automatically set
1413+
# when using binary channels, such as "B","F","P","C","T","A","M","F_pre","F_post".
14111414
_C.LOSS.CLASS_REBALANCE = "none" # Options are 'none', 'manual' or 'auto'
14121415
# If LOSS.CLASS_REBALANCE is set to 'manual', this list of weights will be used to weight each class in the loss calculation.
14131416
# The length of the list must be equal to the number of classes.

0 commit comments

Comments
 (0)