|
1 | | -## Copied from BiaPy commit: 3db4edcf634b3726484d650bc03058c5f36d3a7c (3.6.6) |
| 1 | +## Copied from BiaPy commit: b321fbf3700bec39480dde84fc2d37e4081c9581 (3.6.7) |
| 2 | +""" |
| 3 | +Configuration checking utilities for BiaPy. |
| 4 | +
|
| 5 | +This module provides functions to validate, compare, and update BiaPy configuration |
| 6 | +objects, ensuring that all required settings are present and consistent for a given |
| 7 | +workflow. It includes compatibility checks for data, model, augmentation, and |
| 8 | +post-processing options. |
| 9 | +""" |
2 | 10 | import os |
3 | 11 | import re |
4 | 12 | from typing import List, Tuple, Any, Dict |
@@ -156,6 +164,24 @@ def sort_key(item): |
156 | 164 | if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.GROWTH_MASK_CHANNELS == []: |
157 | 165 | growth_mask_channels = ["F"] |
158 | 166 | growth_mask_channel_ths = ["auto"] |
| 167 | + elif set(sorted_original_instance_channels) == {"F", "Dc"}: |
| 168 | + if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.SEED_CHANNELS == []: |
| 169 | + seed_channels = ["F", "Dc"] |
| 170 | + seed_channels_thresh = ["auto", "auto"] |
| 171 | + if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.TOPOGRAPHIC_SURFACE_CHANNEL == "": |
| 172 | + topo_surface_ch = "F" |
| 173 | + if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.GROWTH_MASK_CHANNELS == []: |
| 174 | + growth_mask_channels = ["F"] |
| 175 | + growth_mask_channel_ths = ["auto"] |
| 176 | + elif set(sorted_original_instance_channels) == {"F", "Dn"}: |
| 177 | + if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.SEED_CHANNELS == []: |
| 178 | + seed_channels = ["F", "Dn"] |
| 179 | + seed_channels_thresh = ["auto", "auto"] |
| 180 | + if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.TOPOGRAPHIC_SURFACE_CHANNEL == "": |
| 181 | + topo_surface_ch = "F" |
| 182 | + if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.GROWTH_MASK_CHANNELS == []: |
| 183 | + growth_mask_channels = ["F"] |
| 184 | + growth_mask_channel_ths = ["auto"] |
159 | 185 | elif set(sorted_original_instance_channels) == {"F", "P"}: |
160 | 186 | if cfg.PROBLEM.INSTANCE_SEG.WATERSHED.SEED_CHANNELS == []: |
161 | 187 | seed_channels = ["F", "P"] |
@@ -386,7 +412,7 @@ def sort_key(item): |
386 | 412 | # Dc — center/skeleton distance-to-center |
387 | 413 | if "Dc" in chs: |
388 | 414 | dst["Dc"] = { |
389 | | - "type": dst.get("Dc", {}).get("mode", "thick"), |
| 415 | + "type": dst.get("Dc", {}).get("mode", "centroid"), |
390 | 416 | "norm": dst.get("Dc", {}).get("norm", True), |
391 | 417 | "mask_values": dst.get("Dc", {}).get("mask_values", True), |
392 | 418 | } |
@@ -1031,8 +1057,14 @@ def sort_key(item): |
1031 | 1057 | "W_CE_DICE", |
1032 | 1058 | ], "LOSS.TYPE not in ['CE', 'DICE', 'W_CE_DICE']" |
1033 | 1059 |
|
1034 | | - if cfg.DATA.N_CLASSES > 2 and loss != "CE": |
1035 | | - raise ValueError("'DATA.N_CLASSES' can only be done with 'CE' loss") |
| 1060 | + if cfg.DATA.N_CLASSES > 2: |
| 1061 | + if loss != "CE": |
| 1062 | + raise ValueError("'DATA.N_CLASSES' are only used with 'CE' loss and not with {}".format(loss)) |
| 1063 | + if cfg.LOSS.CLASS_REBALANCE == "auto": |
| 1064 | + raise ValueError( |
| 1065 | + "'LOSS.CLASS_REBALANCE' can not be set to 'auto' when 'DATA.N_CLASSES' > 2 as it is only valid for binary problems. " \ |
| 1066 | + "Use 'manual' and 'LOSS.CLASS_WEIGHTS' if you really want to rebalance classes. If not, set 'LOSS.CLASS_REBALANCE' to 'none'." |
| 1067 | + ) |
1036 | 1068 | if loss == "W_CE_DICE": |
1037 | 1069 | assert ( |
1038 | 1070 | len(cfg.LOSS.WEIGHTS) == 2 |
@@ -1314,8 +1346,8 @@ def sort_key(item): |
1314 | 1346 | assert isinstance(val["norm"], bool) |
1315 | 1347 | _assert_bool(val, "mask_values", ctx) |
1316 | 1348 |
|
1317 | | - elif key == "Dc": # distance-to-center |
1318 | | - _assert_str_in(val, "type", {"center", "skeleton"}, ctx) |
| 1349 | + elif key == "Dc": # distance-to-centroid |
| 1350 | + _assert_str_in(val, "type", {"centroid", "skeleton"}, ctx) |
1319 | 1351 | _assert_optional_bool(val, "norm", ctx) |
1320 | 1352 | _assert_bool(val, "mask_values", ctx) |
1321 | 1353 |
|
@@ -2259,14 +2291,27 @@ def sort_key(item): |
2259 | 2291 |
|
2260 | 2292 | # Adjust Z_DOWN values to feature maps |
2261 | 2293 | if all(x == 0 for x in cfg.MODEL.Z_DOWN): |
2262 | | - opts.extend(["MODEL.Z_DOWN", (2,) * (len(cfg.MODEL.FEATURE_MAPS) - 1)]) |
| 2294 | + if model_arch == "multiresunet": |
| 2295 | + opts.extend(["MODEL.Z_DOWN", (2, 2, 2, 2)]) |
| 2296 | + else: |
| 2297 | + opts.extend(["MODEL.Z_DOWN", (2,) * (len(cfg.MODEL.FEATURE_MAPS) - 1)]) |
2263 | 2298 | elif any([False for x in cfg.MODEL.Z_DOWN if x != 1 and x != 2]): |
2264 | 2299 | raise ValueError("'MODEL.Z_DOWN' needs to be 1 or 2") |
2265 | 2300 | else: |
2266 | 2301 | if model_arch == "multiresunet" and len(cfg.MODEL.Z_DOWN) != 4: |
2267 | 2302 | raise ValueError("'MODEL.Z_DOWN' length must be 4 when using 'multiresunet'") |
2268 | | - elif len(cfg.MODEL.FEATURE_MAPS) - 1 != len(cfg.MODEL.Z_DOWN): |
2269 | | - raise ValueError("'MODEL.FEATURE_MAPS' length minus one and 'MODEL.Z_DOWN' length must be equal") |
| 2303 | + elif model_arch in [ |
| 2304 | + "unet", |
| 2305 | + "resunet", |
| 2306 | + "resunet++", |
| 2307 | + "seunet", |
| 2308 | + "resunet_se", |
| 2309 | + "attention_unet", |
| 2310 | + "unext_v1", |
| 2311 | + "unext_v2", |
| 2312 | + ]: |
| 2313 | + if len(cfg.MODEL.FEATURE_MAPS) - 1 != len(cfg.MODEL.Z_DOWN): |
| 2314 | + raise ValueError("'MODEL.FEATURE_MAPS' length minus one and 'MODEL.Z_DOWN' length must be equal") |
2270 | 2315 |
|
2271 | 2316 | # Adjust ISOTROPY values to feature maps |
2272 | 2317 | if all(x == True for x in cfg.MODEL.ISOTROPY): |
@@ -3275,7 +3320,6 @@ def convert_old_model_cfg_to_current_version(old_cfg: dict): |
3275 | 3320 |
|
3276 | 3321 | return old_cfg |
3277 | 3322 |
|
3278 | | - |
3279 | 3323 | # Function extracted from check_configuration checks |
3280 | 3324 | def check_torchvision_available_models(workflow: str, ndim: str) -> Tuple[List[str], List[str], List[Dict[str, Any]]]: |
3281 | 3325 | """ |
|
0 commit comments