Skip to content

Commit f846455

Browse files
Add CLI --set flag for generic config overrides and improve config API
- Add --set flag to override any config parameter with dot notation - Remove --detect-anomaly flag (use --set model.detect_anomaly=true instead) - Make config utility functions public (remove underscore prefix) - Add comprehensive type hints to config.py - Switch CLI to use load_config() for consistent config handling - Move config imports to top-level (PEP 8 compliance) - Add CLI override documentation to config_loader.rst - Update help examples to show --set usage
1 parent f73ca9e commit f846455

File tree

3 files changed

+140
-56
lines changed

3 files changed

+140
-56
lines changed

docs/source/config_loader.rst

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,42 @@ Usage in Python
184184
print(cfg['io']['loader']['batch_size']) # 8
185185
print(cfg['model']['modules']['uresnet']['depth']) # 5
186186
187+
Command-Line Overrides
188+
----------------------
189+
190+
When using the SPINE CLI, you can override any configuration parameter using the ``--set`` flag with dot notation:
191+
192+
.. code-block:: bash
193+
194+
# Override a single parameter
195+
spine -c config.yaml --set io.loader.batch_size=8
196+
197+
# Override multiple parameters
198+
spine -c config.yaml \
199+
--set base.iterations=1000 \
200+
--set io.loader.batch_size=16 \
201+
--set io.loader.dataset.file_keys=[file1.root,file2.root]
202+
203+
# Mix with other CLI options
204+
spine -c config.yaml \
205+
--source /data/input.root \
206+
--output /data/output.h5 \
207+
--set model.weight_path=/weights/model.ckpt
208+
209+
The ``--set`` flag accepts any valid YAML value:
210+
211+
- **Strings**: ``--set model.name=my_model``
212+
- **Numbers**: ``--set base.iterations=1000`` or ``--set base.learning_rate=0.001``
213+
- **Booleans**: ``--set io.loader.shuffle=true``
214+
- **Lists**: ``--set io.loader.dataset.file_keys=[file1.root,file2.root]``
215+
- **Nested paths**: ``--set io.loader.dataset.schema.data.num_features=8``
216+
217+
This is particularly useful for:
218+
219+
- **Hyperparameter sweeps**: Quickly test different values without editing config files
220+
- **Production runs**: Override paths and settings for different environments
221+
- **Debugging**: Enable/disable features or adjust batch sizes on the fly
222+
187223
Benefits
188224
--------
189225

src/spine/bin/cli.py

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pathlib
77
import sys
88

9-
import yaml
9+
from spine.utils.config import load_config, parse_value, set_nested_value
1010

1111

1212
def main(
@@ -16,10 +16,10 @@ def main(
1616
output,
1717
n,
1818
nskip,
19-
detect_anomaly,
2019
log_dir,
2120
weight_prefix,
2221
weight_path,
22+
config_overrides,
2323
):
2424
"""Main driver for training/validation/inference/analysis.
2525
@@ -41,15 +41,15 @@ def main(
4141
Number of iterations to run
4242
nskip : int
4343
Number of iterations to skip
44-
detect_anomaly : bool
45-
Whether to turn on anomaly detection in torch
4644
log_dir : str
4745
Path to the directory for storing the training log
4846
weight_prefix : str
4947
Path to the directory for storing the training weights
5048
weight_path : str
5149
Path string a weight file or pattern for multiple weight files to load
5250
the model weights
51+
config_overrides : List[str]
52+
List of config overrides in the form "key.path=value"
5353
"""
5454
# Try to find configuration file using the absolute path or under
5555
# the 'config' directory relative to the current working directory
@@ -60,9 +60,8 @@ def main(
6060
if not os.path.isfile(cfg_file):
6161
raise FileNotFoundError(f"Configuration not found: {config}")
6262

63-
# Load the configuration file
64-
with open(cfg_file, "r", encoding="utf-8") as cfg_yaml:
65-
cfg = yaml.safe_load(cfg_yaml)
63+
# Load the configuration file using the advanced loader
64+
cfg = load_config(cfg_file)
6665

6766
# If there is no base block, build one
6867
if "base" not in cfg:
@@ -117,12 +116,24 @@ def main(
117116
if weight_path is not None:
118117
cfg["model"]["weight_path"] = weight_path
119118

120-
# Turn on PyTorch anomaly detection, if requested
121-
if detect_anomaly:
122-
assert (
123-
"model" in cfg
124-
), "There is no model to detect anomalies for, add `model` block."
125-
cfg["model"]["detect_anomaly"] = detect_anomaly
119+
# Apply any generic config overrides from --set arguments
120+
if config_overrides:
121+
for override in config_overrides:
122+
if "=" not in override:
123+
raise ValueError(
124+
f"Invalid --set format: '{override}'. "
125+
f"Expected format: 'key.path=value'"
126+
)
127+
128+
key_path, value_str = override.split("=", 1)
129+
key_path = key_path.strip()
130+
value_str = value_str.strip()
131+
132+
# Parse the value (handles strings, numbers, booleans, lists, etc.)
133+
value = parse_value(value_str)
134+
135+
# Set the nested value
136+
cfg = set_nested_value(cfg, key_path, value)
126137

127138
# For actual training/inference, we need the main functionality
128139
from spine.main import run
@@ -138,10 +149,13 @@ def cli():
138149
formatter_class=argparse.RawDescriptionHelpFormatter,
139150
epilog="""
140151
Examples:
141-
spine --version Show version information
142-
spine --info Show system and dependency info
143-
spine config.cfg Run ML training/inference with config file
144-
spine --help Show this help message
152+
spine --version Show version information
153+
spine --info Show system and dependency info
154+
spine -c config.cfg Run ML training/inference with config file
155+
spine -c config.cfg --set io.loader.batch_size=8 Override config parameters
156+
spine -c config.cfg --set base.iterations=1000 --set io.loader.batch_size=16
157+
spine -c config.cfg --set model.detect_anomaly=true Debug PyTorch issues
158+
spine --help Show this help message
145159
146160
For ML training/inference functionality, ensure PyTorch is installed:
147161
pip install spine-ml[model]
@@ -187,12 +201,6 @@ def cli():
187201

188202
parser.add_argument("--nskip", type=int, help="Number of iterations to skip")
189203

190-
parser.add_argument(
191-
"--detect-anomaly",
192-
action="store_true",
193-
help="Whether to turn on anomaly detection in torch",
194-
)
195-
196204
parser.add_argument(
197205
"--log-dir", help="Path to the directory for storing the training log"
198206
)
@@ -206,6 +214,16 @@ def cli():
206214
help="Path string a weight file or pattern for multiple weight files to load model weights",
207215
)
208216

217+
parser.add_argument(
218+
"--set",
219+
action="append",
220+
dest="config_overrides",
221+
metavar="KEY=VALUE",
222+
help="Override any config parameter using dot notation "
223+
"(e.g., --set io.loader.batch_size=8). "
224+
"Can be used multiple times for multiple overrides.",
225+
)
226+
209227
args = parser.parse_args()
210228

211229
# If no arguments provided and no config, show help
@@ -228,10 +246,10 @@ def cli():
228246
output=args.output,
229247
n=args.iterations,
230248
nskip=args.nskip,
231-
detect_anomaly=args.detect_anomaly,
232249
log_dir=args.log_dir,
233250
weight_prefix=args.weight_prefix,
234251
weight_path=args.weight_path,
252+
config_overrides=args.config_overrides or [],
235253
)
236254

237255

0 commit comments

Comments
 (0)