Skip to content

Commit e6b15b9

Browse files
committed
fix: load global config tree from config file instead of local subtree
1 parent cea24de commit e6b15b9

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

examples/train_config.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
epochs: 200
2-
lr: 0.002
1+
train:
2+
epochs: 200
3+
lr: 0.002

src/nanocli/core.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,20 @@ def _execute(self, args: list[str]) -> None:
197197
current._show_command_help(console, part, schema)
198198
return
199199

200-
# Get cfg file
200+
# Get cfg file - load global tree and extract subtree
201201
cfg_file = flags["cfg"]
202-
base = load_yaml(cfg_file) if cfg_file else None
202+
base = None
203+
if cfg_file:
204+
global_cfg = load_yaml(cfg_file)
205+
# Extract subtree based on consumed path
206+
base = global_cfg
207+
for path_key in consumed_path:
208+
if hasattr(base, path_key) or (isinstance(base, dict) and path_key in base):
209+
base = base[path_key]
210+
else:
211+
# Path doesn't exist in config, use None
212+
base = None
213+
break
203214

204215
# Compile config
205216
if schema:

tests/test_core.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,8 +402,9 @@ def test_config_with_yaml(self, tmp_path: Path):
402402
def train(cfg: SimpleConfig):
403403
print(f"Training {cfg.name}")
404404

405+
# Global tree structure - train subtree will be extracted
405406
yaml_file = tmp_path / "config.yml"
406-
yaml_file.write_text("name: from_yaml\ncount: 42")
407+
yaml_file.write_text("train:\n name: from_yaml\n count: 42")
407408

408409
old_stdout = sys.stdout
409410
sys.stdout = io.StringIO()
@@ -622,8 +623,9 @@ def test_print_with_yaml_file(self, tmp_path: Path):
622623
def train(cfg: SimpleConfig):
623624
pass
624625

626+
# Global tree structure
625627
yaml_file = tmp_path / "config.yml"
626-
yaml_file.write_text("name: yaml_test\ncount: 99")
628+
yaml_file.write_text("train:\n name: yaml_test\n count: 99")
627629

628630
old_stdout = sys.stdout
629631
sys.stdout = io.StringIO()

0 commit comments

Comments
 (0)