clean-up in config.py focusing on shared path #1579
base: develop
Conversation
grassesi
left a comment
Great work. In some places there is potential for even more simplification; please have a look at my suggestions.
    # Cache the expensive private config loading operation
    _shared_wg_base_path = None


    def _get_shared_wg_base_path() -> Path:
        """Get the shared working directory base path, cached after first call."""
        global _shared_wg_base_path
        if _shared_wg_base_path is None:
            pcfg = _load_private_conf()
            _shared_wg_base_path = Path(pcfg.get("path_shared_working_dir"))
        return _shared_wg_base_path
Python has a very simple mechanism for caching results: you can decorate a function with `@cache` from `functools`:

    from functools import cache

    @cache
    def _get_shared_wg_base_path() -> Path:
        """Get the shared working directory base path, cached after first call."""
        private_config = _load_private_conf()
        return Path(private_config.get("path_shared_working_dir"))
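To make the caching behaviour concrete, here is a minimal, self-contained sketch. `_load_private_conf` is replaced by a hypothetical stand-in that counts its calls, and the path `/tmp/shared` is made up for illustration. `functools.cache` requires Python 3.9 or newer.

```python
from functools import cache
from pathlib import Path

load_calls = 0  # counts how often the stand-in loader actually runs


def _load_private_conf() -> dict:
    """Hypothetical stand-in for the real, expensive private config loader."""
    global load_calls
    load_calls += 1
    return {"path_shared_working_dir": "/tmp/shared"}


@cache
def _get_shared_wg_base_path() -> Path:
    """Get the shared working directory base path, cached after first call."""
    private_config = _load_private_conf()
    return Path(private_config["path_shared_working_dir"])


first = _get_shared_wg_base_path()
second = _get_shared_wg_base_path()  # served from the cache, loader not called again

# cache_clear() resets the cache, e.g. for tests that swap the private config.
_get_shared_wg_base_path.cache_clear()
third = _get_shared_wg_base_path()  # loader runs a second time
```

Compared with the module-level global, the decorator keeps the state out of the module namespace and gives tests `cache_clear()` for free.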
    def set_paths(config: Config) -> Config:
        """Set the configs run_path model_path attributes to default values if not present."""
        config = config.copy()
        config.run_path = _get_config_attribute(
            config=config, attribute_name="run_path", fallback="results"
        )
        config.model_path = _get_config_attribute(
            config=config, attribute_name="model_path", fallback="models"
        )

        return config


    def _get_config_attribute(config: Config, attribute_name: str, fallback: str) -> str:
        """Get an attribute from a Config. If not available, fall back to path_shared_working_dir
        concatenated with the desired fallback path. Raise an error if neither the attribute nor a
        fallback is specified."""
        attribute = OmegaConf.select(config, attribute_name)
        fallback_root = OmegaConf.select(config, "path_shared_working_dir")
        assert attribute is not None or fallback_root is not None, (
            f"Must specify `{attribute_name}` in config if `path_shared_working_dir` is None in config"
        )
        attribute = attribute if attribute else fallback_root + fallback
        return attribute
Nice, I am glad that this confusing logic is gone.
    result_dir_base = Path(cf.run_path)
    result_dir_base = config._get_shared_wg_path("results")
    result_dir = result_dir_base / run_id
Please use `result_dir = config.get_path_run(cf)` here.
    metrics_path = get_train_metrics_path(
        base_path=Path(self.cf.run_path), run_id=self.cf.run_id
        base_path=config._get_shared_wg_path("results"), run_id=self.cf.run_id
    )
Please use `config.get_path_run(self.cf)` here.
    result_dir_base = Path(cf.run_path)
    result_dir_base = config._get_shared_wg_path("results")
    result_dir = result_dir_base / run_id
See above.
        """
        pcfg = _load_private_conf()
        return Path(pcfg.get("path_shared_working_dir")) / local_path
        return _get_shared_wg_base_path() / local_path
There is no need for this method. Everywhere outside this module, `config.get_path_...(cf)` should be used. Inside this module, it is just a more indirect way of writing `_get_shared_wg_base_path() / local_path`. Please remove this method and rename `_get_shared_wg_base_path()` to `_get_shared_wg_path()` instead. Then use the renamed helper to implement `get_path_run` and `get_path_model`.
    model_path = str(_get_shared_wg_path("models"))
    path = Path(model_path)
Use `model_path = _get_shared_wg_path() / "models"` here. `model_path` does not need to be a `str`; see my comments on `_get_shared_wg_path`.
If you implement the suggestion from my comment on `get_path_model`, you could even write `model_path = get_path_model(run_id=run_id)`.
    def get_path_model(config: Config) -> Path:
        """Get the current runs model_path for storing model checkpoints."""
        return Path(config.model_path) / config.run_id
        model_path = _get_shared_wg_path("models")
        return model_path / config.run_id
It would be nice if this method accepted either a config object or a `run_id` directly, e.g.:

    def get_path_model(config: Config | None = None, run_id: str | None = None) -> Path:
        """Get the model path from an explicit run_id, else from config.run_id."""
        if config or run_id:
            run_id = run_id if run_id else config.run_id
        else:
            msg = f"Missing run_id and cannot infer it from config: {config}"
            raise ValueError(msg)
        return _get_shared_wg_path() / "models" / run_id

Then we could use it in `load_run_config` / `_get_model_config_file_read_name`: `get_path_model(run_id=run_id)`.
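For illustration, the three ways such a dual-argument function could be called, plus the error case. This is only a sketch: `_get_shared_wg_path` is a hypothetical stand-in returning a made-up path, a `SimpleNamespace` stands in for a `Config` object, and the checks use `is None` so that an empty config is not mistaken for a missing one.

```python
from pathlib import Path
from types import SimpleNamespace
from typing import Optional


def _get_shared_wg_path() -> Path:
    """Hypothetical stand-in; the real helper reads the private config."""
    return Path("/tmp/shared")


def get_path_model(config=None, run_id: Optional[str] = None) -> Path:
    """Resolve the model directory from an explicit run_id, else from config.run_id."""
    if run_id is None:
        if config is None:
            raise ValueError("Missing run_id and no config to infer it from")
        run_id = config.run_id
    return _get_shared_wg_path() / "models" / run_id


cf = SimpleNamespace(run_id="abc123")  # stands in for a Config object

from_config = get_path_model(cf)
from_run_id = get_path_model(run_id="xyz789")
explicit_wins = get_path_model(cf, run_id="xyz789")  # run_id takes precedence

try:
    get_path_model()
    error_message = None
except ValueError as err:
    error_message = str(err)
```

Letting an explicit `run_id` win over `config.run_id` matches the precedence in the suggestion above.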
    dirname = path_models / config.run_id
    dirname = get_path_model(config)
    dirname.mkdir(exist_ok=True, parents=True)

    fname = _get_model_config_file_write_name(path_models, config.run_id, mini_epoch)
    path_models_parent = dirname.parent
    fname = _get_model_config_file_write_name(path_models_parent, config.run_id, mini_epoch)
Please adjust `_get_model_config_file_write_name` and `_get_model_config_file_read_name` to return only the file name, not the entire path of the config file, e.g.:

    dirname = get_path_model(config)
    dirname.mkdir(exist_ok=True, parents=True)
    fname = _get_model_config_file_write_name(config.run_id, mini_epoch)
    json_str = json.dumps(OmegaConf.to_container(_strip_interpolation(config)))
    with (dirname / fname).open("w") as f:
        f.write(json_str)


    base_config = load_run_config(
        from_run_id, mini_epoch, private_config.get("model_path", None)
    )
    base_config = load_run_config(from_run_id, mini_epoch, _get_shared_wg_path("models"))
No need for `_get_shared_wg_path() / "models"` here, since it will be retrieved again in `load_run_config` anyway.
- Simplify get_path_model/get_path_run to always resolve via _get_shared_wg_path()
- Change _get_shared_wg_path() to cached, argument-free helper returning the shared working dir from private config
- Adjust model config save/load to build filenames relative to the run’s model directory instead of passing parent paths around
- Update load_run_config and load_merge_configs to use new path helpers and improve assertion/log messages
- Replace internal _get_shared_wg_path("results") usages with get_path_run() in wegen_reader and train_logger
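The third item above, building file names relative to the run's model directory, could look roughly like the sketch below. The file naming scheme, the `save_config` helper, and the integer `mini_epoch` are assumptions made for illustration; the project's actual names and formats may differ.

```python
import json
import tempfile
from pathlib import Path


def _get_model_config_file_write_name(run_id: str, mini_epoch: int) -> str:
    """Return only the file name. The naming scheme here is hypothetical."""
    return f"model_{run_id}_chkpt{mini_epoch:05d}.json"


def save_config(config: dict, model_dir: Path, run_id: str, mini_epoch: int) -> Path:
    """Write the config as JSON into the run's model directory (hypothetical helper)."""
    model_dir.mkdir(exist_ok=True, parents=True)
    path = model_dir / _get_model_config_file_write_name(run_id, mini_epoch)
    with path.open("w") as f:
        f.write(json.dumps(config))
    return path


with tempfile.TemporaryDirectory() as tmp:
    # model_dir plays the role of what get_path_model(config) would return.
    model_dir = Path(tmp) / "models" / "abc123"
    written = save_config({"lr": 0.001}, model_dir, "abc123", mini_epoch=3)
    written_name = written.name
    loaded = json.loads(written.read_text())
```

Because the helper returns a bare file name, the directory is decided in exactly one place and no parent paths need to be passed around.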
Description
Issue Number
Closes #1061
Is this PR a draft? Mark it as draft.
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60