66import pathlib
77import sys
88
9- import yaml
9+ from spine . utils . config import load_config , parse_value , set_nested_value
1010
1111
1212def main (
@@ -16,10 +16,10 @@ def main(
1616 output ,
1717 n ,
1818 nskip ,
19- detect_anomaly ,
2019 log_dir ,
2120 weight_prefix ,
2221 weight_path ,
22+ config_overrides ,
2323):
2424 """Main driver for training/validation/inference/analysis.
2525
@@ -41,15 +41,15 @@ def main(
4141 Number of iterations to run
4242 nskip : int
4343 Number of iterations to skip
44- detect_anomaly : bool
45- Whether to turn on anomaly detection in torch
4644 log_dir : str
4745 Path to the directory for storing the training log
4846 weight_prefix : str
4947 Path to the directory for storing the training weights
5048 weight_path : str
5149 Path string a weight file or pattern for multiple weight files to load
5250 the model weights
51+ config_overrides : List[str]
52+ List of config overrides in the form "key.path=value"
5353 """
5454 # Try to find configuration file using the absolute path or under
5555 # the 'config' directory relative to the current working directory
@@ -60,9 +60,8 @@ def main(
6060 if not os .path .isfile (cfg_file ):
6161 raise FileNotFoundError (f"Configuration not found: { config } " )
6262
63- # Load the configuration file
64- with open (cfg_file , "r" , encoding = "utf-8" ) as cfg_yaml :
65- cfg = yaml .safe_load (cfg_yaml )
63+ # Load the configuration file using the advanced loader
64+ cfg = load_config (cfg_file )
6665
6766 # If there is no base block, build one
6867 if "base" not in cfg :
@@ -117,12 +116,24 @@ def main(
117116 if weight_path is not None :
118117 cfg ["model" ]["weight_path" ] = weight_path
119118
120- # Turn on PyTorch anomaly detection, if requested
121- if detect_anomaly :
122- assert (
123- "model" in cfg
124- ), "There is no model to detect anomalies for, add `model` block."
125- cfg ["model" ]["detect_anomaly" ] = detect_anomaly
119+ # Apply any generic config overrides from --set arguments
120+ if config_overrides :
121+ for override in config_overrides :
122+ if "=" not in override :
123+ raise ValueError (
124+ f"Invalid --set format: '{ override } '. "
125+ f"Expected format: 'key.path=value'"
126+ )
127+
128+ key_path , value_str = override .split ("=" , 1 )
129+ key_path = key_path .strip ()
130+ value_str = value_str .strip ()
131+
132+ # Parse the value (handles strings, numbers, booleans, lists, etc.)
133+ value = parse_value (value_str )
134+
135+ # Set the nested value
136+ cfg = set_nested_value (cfg , key_path , value )
126137
127138 # For actual training/inference, we need the main functionality
128139 from spine .main import run
@@ -138,10 +149,13 @@ def cli():
138149 formatter_class = argparse .RawDescriptionHelpFormatter ,
139150 epilog = """
140151Examples:
141- spine --version Show version information
142- spine --info Show system and dependency info
143- spine config.cfg Run ML training/inference with config file
144- spine --help Show this help message
152+ spine --version Show version information
153+ spine --info Show system and dependency info
154+ spine -c config.cfg Run ML training/inference with config file
155+ spine -c config.cfg --set io.loader.batch_size=8 Override config parameters
156+ spine -c config.cfg --set base.iterations=1000 --set io.loader.batch_size=16
157+ spine -c config.cfg --set model.detect_anomaly=true Debug PyTorch issues
158+ spine --help Show this help message
145159
146160For ML training/inference functionality, ensure PyTorch is installed:
147161 pip install spine-ml[model]
@@ -187,12 +201,6 @@ def cli():
187201
188202 parser .add_argument ("--nskip" , type = int , help = "Number of iterations to skip" )
189203
190- parser .add_argument (
191- "--detect-anomaly" ,
192- action = "store_true" ,
193- help = "Whether to turn on anomaly detection in torch" ,
194- )
195-
196204 parser .add_argument (
197205 "--log-dir" , help = "Path to the directory for storing the training log"
198206 )
@@ -206,6 +214,16 @@ def cli():
206214 help = "Path string a weight file or pattern for multiple weight files to load model weights" ,
207215 )
208216
217+ parser .add_argument (
218+ "--set" ,
219+ action = "append" ,
220+ dest = "config_overrides" ,
221+ metavar = "KEY=VALUE" ,
222+ help = "Override any config parameter using dot notation "
223+ "(e.g., --set io.loader.batch_size=8). "
224+ "Can be used multiple times for multiple overrides." ,
225+ )
226+
209227 args = parser .parse_args ()
210228
211229 # If no arguments provided and no config, show help
@@ -228,10 +246,10 @@ def cli():
228246 output = args .output ,
229247 n = args .iterations ,
230248 nskip = args .nskip ,
231- detect_anomaly = args .detect_anomaly ,
232249 log_dir = args .log_dir ,
233250 weight_prefix = args .weight_prefix ,
234251 weight_path = args .weight_path ,
252+ config_overrides = args .config_overrides or [],
235253 )
236254
237255
0 commit comments