Skip to content

Commit 124d090

Browse files
QwlouseThe kauldron Authors
authored andcommitted
Kauldron CLI: small UX fixes and conflict resolution
PiperOrigin-RevId: 883656994
1 parent 6e373fb commit 124d090

File tree

4 files changed

+16
-12
lines changed

4 files changed

+16
-12
lines changed

kauldron/cli/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __call__(self):
3535
class Resolve(cu.SubCommand):
3636
"""Resolve and display the fully-instantiated config."""
3737

38-
verbose: bool = False
38+
verbose: bool = True
3939

4040
def __call__(self):
4141
self.print_config_origin()

kauldron/cli/config_test.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,21 @@
1818
from kauldron.cli import config
1919

2020

21-
def test_show_returns_repr(capsys):
22-
"""show() should print the unresolved ConfigDict."""
23-
21+
def test_show_prints_repr(capsys):
2422
cfg = konfig.ConfigDict({"seed": 42, "num_train_steps": 1000})
2523
cmd = config.Show(cfg=cfg)
2624
cmd()
27-
result = capsys.readouterr().out
2825

29-
assert "'seed': 42" in result
30-
assert "'num_train_steps': 1000" in result
26+
captured = capsys.readouterr().out
27+
assert "'seed': 42" in captured
28+
assert "'num_train_steps': 1000" in captured
3129

3230

33-
def test_resolve_simple(capsys):
31+
def test_resolve_prints_config(capsys):
3432
cfg = konfig.ConfigDict({"seed": 42, "num_train_steps": 1000})
3533
cmd = config.Resolve(cfg=cfg, verbose=True)
3634
cmd()
37-
result = capsys.readouterr().out
38-
assert "'seed': 42" in result
39-
assert "'num_train_steps': 1000" in result
35+
36+
captured = capsys.readouterr().out
37+
assert "'seed': 42" in captured
38+
assert "'num_train_steps': 1000" in captured

kauldron/cli/main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,11 @@ def main(args: Args) -> None:
134134

135135
patched_config, patches = patcher(cfg)
136136

137+
conflicts = set(overrides) & set(patches)
138+
for key in conflicts:
139+
cu.tracked_update(patched_config, key, overrides[key])
140+
del patches[key]
141+
137142
origin = patch_config.ConfigOrigin(
138143
filename=filename,
139144
overrides=overrides,

kauldron/cli/patch_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def _patch_batch_size(self, cfg: konfig.ConfigDict) -> dict[str, Any]:
129129
def _patch_num_batches(self, cfg: konfig.ConfigDict) -> dict[str, Any]:
130130
"""Patches the number of batches."""
131131
updates = {}
132-
if self.num_batches is not None:
132+
if self.num_batches:
133133
if hasattr(cfg, "evals"):
134134
updates |= cu.tracked_update(
135135
cfg, "evals.**.num_batches", self.num_batches

0 commit comments

Comments
 (0)