Skip to content

Commit fe1d60f

Browse files
author
Donglai Wei
committed
monai2d_worm
1 parent 3f23c4f commit fe1d60f

File tree

12 files changed

+693
-451
lines changed

12 files changed

+693
-451
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Choose your preferred method:
4747
<summary><b>🚀 One-Command Install (Recommended)</b></summary>
4848

4949
```bash
50-
curl -fsSL https://raw.githubusercontent.com/zudi-lin/pytorch_connectomics/v2.0/quickstart.sh | bash
50+
curl -fsSL https://raw.githubusercontent.com/zudi-lin/pytorch_connectomics/refs/heads/master/quickstart.sh | bash
5151
conda activate pytc
5252
```
5353

connectomics/config/hydra_config.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,24 @@ class LabelTransformConfig:
287287
targets: List[Any] = field(default_factory=list)
288288

289289

290+
@dataclass
291+
class DataTransformConfig:
292+
"""Data transformation configuration applied to all data (image/label/mask).
293+
294+
These transforms are applied to paired data (image, label, mask) to ensure
295+
spatial alignment. Each transform uses appropriate interpolation:
296+
- Image: bilinear interpolation (smooth)
297+
- Label/Mask: nearest-neighbor interpolation (preserves integer values)
298+
"""
299+
300+
resize: Optional[List[float]] = (
301+
None # Resize to target size [H, W] for 2D or [D, H, W] for 3D. None = no resize.
302+
)
303+
304+
290305
@dataclass
291306
class ImageTransformConfig:
292-
"""Image transformation configuration."""
307+
"""Image transformation configuration (applied to image only)."""
293308

294309
normalize: str = "0-1" # "none", "normal" (z-score), or "0-1" (min-max)
295310
clip_percentile_low: float = (
@@ -298,9 +313,6 @@ class ImageTransformConfig:
298313
clip_percentile_high: float = (
299314
1.0 # Upper percentile for clipping (1.0 = no clip, 0.95 = 95th percentile)
300315
)
301-
resize: Optional[List[float]] = (
302-
None # Resize factors [H_scale, W_scale] for 2D or [D_scale, H_scale, W_scale] for 3D. None = no resize. Uses bilinear interpolation for images.
303-
)
304316

305317

306318
@dataclass
@@ -389,7 +401,10 @@ class DataConfig:
389401
use_cache: bool = False
390402
cache_rate: float = 1.0
391403

392-
# Image transformation
404+
# Data transformation (applied to image/label/mask)
405+
data_transform: DataTransformConfig = field(default_factory=DataTransformConfig)
406+
407+
# Image transformation (applied to image only)
393408
image_transform: ImageTransformConfig = field(default_factory=ImageTransformConfig)
394409

395410
# Sampling (for volumetric datasets)

connectomics/config/hydra_utils.py

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -42,34 +42,31 @@ def load_config(config_path: Union[str, Path]) -> Config:
4242
def save_config(cfg: Config, save_path: Union[str, Path]) -> None:
4343
"""
4444
Save configuration to YAML file.
45-
45+
4646
Args:
4747
cfg: Config object to save
4848
save_path: Path where to save the YAML file
4949
"""
5050
save_path = Path(save_path)
5151
save_path.parent.mkdir(parents=True, exist_ok=True)
52-
52+
5353
omega_conf = OmegaConf.structured(cfg)
5454
OmegaConf.save(omega_conf, save_path)
5555

5656

57-
def merge_configs(
58-
base_cfg: Config,
59-
*override_cfgs: Union[Config, Dict, str, Path]
60-
) -> Config:
57+
def merge_configs(base_cfg: Config, *override_cfgs: Union[Config, Dict, str, Path]) -> Config:
6158
"""
6259
Merge multiple configurations together.
63-
60+
6461
Args:
6562
base_cfg: Base configuration
6663
*override_cfgs: One or more override configs (Config, dict, or path to YAML)
67-
64+
6865
Returns:
6966
Merged Config object
7067
"""
7168
result = OmegaConf.structured(base_cfg)
72-
69+
7370
for override_cfg in override_cfgs:
7471
if isinstance(override_cfg, (str, Path)):
7572
override_omega = OmegaConf.load(override_cfg)
@@ -79,22 +76,22 @@ def merge_configs(
7976
override_omega = OmegaConf.create(override_cfg)
8077
else:
8178
raise TypeError(f"Unsupported config type: {type(override_cfg)}")
82-
79+
8380
result = OmegaConf.merge(result, override_omega)
84-
81+
8582
return OmegaConf.to_object(result)
8683

8784

8885
def update_from_cli(cfg: Config, overrides: List[str]) -> Config:
8986
"""
9087
Update config from command-line overrides.
91-
88+
9289
Supports dot notation: ['data.batch_size=4', 'model.architecture=unet3d']
93-
90+
9491
Args:
9592
cfg: Base Config object
9693
overrides: List of 'key=value' strings
97-
94+
9895
Returns:
9996
Updated Config object
10097
"""
@@ -107,11 +104,11 @@ def update_from_cli(cfg: Config, overrides: List[str]) -> Config:
107104
def to_dict(cfg: Config, resolve: bool = True) -> Dict[str, Any]:
108105
"""
109106
Convert Config to dictionary.
110-
107+
111108
Args:
112109
cfg: Config object
113110
resolve: Whether to resolve variable interpolations
114-
111+
115112
Returns:
116113
Dictionary representation
117114
"""
@@ -122,10 +119,10 @@ def to_dict(cfg: Config, resolve: bool = True) -> Dict[str, Any]:
122119
def from_dict(d: Dict[str, Any]) -> Config:
123120
"""
124121
Create Config from dictionary.
125-
122+
126123
Args:
127124
d: Dictionary with configuration values
128-
125+
129126
Returns:
130127
Config object
131128
"""
@@ -138,7 +135,7 @@ def from_dict(d: Dict[str, Any]) -> Config:
138135
def print_config(cfg: Config, resolve: bool = True) -> None:
139136
"""
140137
Pretty print configuration.
141-
138+
142139
Args:
143140
cfg: Config to print
144141
resolve: Whether to resolve variable interpolations
@@ -150,10 +147,10 @@ def print_config(cfg: Config, resolve: bool = True) -> None:
150147
def validate_config(cfg: Config) -> None:
151148
"""
152149
Validate configuration values.
153-
150+
154151
Args:
155152
cfg: Config object to validate
156-
153+
157154
Raises:
158155
ValueError: If configuration is invalid
159156
"""
@@ -162,18 +159,18 @@ def validate_config(cfg: Config) -> None:
162159
raise ValueError("model.in_channels must be positive")
163160
if cfg.model.out_channels <= 0:
164161
raise ValueError("model.out_channels must be positive")
165-
if len(cfg.model.input_size) != 3:
166-
raise ValueError("model.input_size must be 3D")
167-
162+
if len(cfg.model.input_size) not in [2, 3]:
163+
raise ValueError("model.input_size must be 2D or 3D (got length {})".format(len(cfg.model.input_size)))
164+
168165
# System validation
169166
if cfg.system.training.batch_size <= 0:
170167
raise ValueError("system.training.batch_size must be positive")
171168
if cfg.system.training.num_workers < 0:
172169
raise ValueError("system.training.num_workers must be non-negative")
173170

174171
# Data validation
175-
if len(cfg.data.patch_size) != 3:
176-
raise ValueError("data.patch_size must be 3D")
172+
if len(cfg.data.patch_size) not in [2, 3]:
173+
raise ValueError("data.patch_size must be 2D or 3D (got length {})".format(len(cfg.data.patch_size)))
177174

178175
# Optimizer validation
179176
if cfg.optimization.optimizer.lr <= 0:
@@ -188,7 +185,7 @@ def validate_config(cfg: Config) -> None:
188185
raise ValueError("optimization.gradient_clip_val must be non-negative")
189186
if cfg.optimization.accumulate_grad_batches <= 0:
190187
raise ValueError("optimization.accumulate_grad_batches must be positive")
191-
188+
192189
# Loss validation
193190
if len(cfg.model.loss_functions) != len(cfg.model.loss_weights):
194191
raise ValueError("loss_functions and loss_weights must have same length")
@@ -199,16 +196,17 @@ def validate_config(cfg: Config) -> None:
199196
def get_config_hash(cfg: Config) -> str:
200197
"""
201198
Generate a hash string for the configuration.
202-
199+
203200
Useful for experiment tracking and reproducibility.
204-
201+
205202
Args:
206203
cfg: Config object
207-
204+
208205
Returns:
209206
Hash string
210207
"""
211208
import hashlib
209+
212210
omega_conf = OmegaConf.structured(cfg)
213211
yaml_str = OmegaConf.to_yaml(omega_conf, resolve=True)
214212
return hashlib.md5(yaml_str.encode()).hexdigest()[:8]
@@ -228,20 +226,20 @@ def create_experiment_name(cfg: Config) -> str:
228226
cfg.model.architecture,
229227
f"bs{cfg.system.training.batch_size}",
230228
f"lr{cfg.optimization.optimizer.lr:.0e}",
231-
get_config_hash(cfg)
229+
get_config_hash(cfg),
232230
]
233231
return "_".join(parts)
234232

235233

236234
__all__ = [
237-
'load_config',
238-
'save_config',
239-
'merge_configs',
240-
'update_from_cli',
241-
'to_dict',
242-
'from_dict',
243-
'print_config',
244-
'validate_config',
245-
'get_config_hash',
246-
'create_experiment_name',
247-
]
235+
"load_config",
236+
"save_config",
237+
"merge_configs",
238+
"update_from_cli",
239+
"to_dict",
240+
"from_dict",
241+
"print_config",
242+
"validate_config",
243+
"get_config_hash",
244+
"create_experiment_name",
245+
]

0 commit comments

Comments
 (0)