Skip to content

Commit 0ec5d5e

Browse files
authored
Allow aliases for run_opt
Differential Revision: D84157870 Pull Request resolved: #1141
1 parent 6c03c7e commit 0ec5d5e

File tree

2 files changed

+111
-6
lines changed

2 files changed

+111
-6
lines changed

torchx/specs/api.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -891,10 +891,14 @@ class runopt:
891891
Represents the metadata about the specific run option
892892
"""
893893

894+
class alias(str):
895+
pass
896+
894897
default: CfgVal
895898
opt_type: Type[CfgVal]
896899
is_required: bool
897900
help: str
901+
aliases: list[alias] | None = None
898902

899903
@property
900904
def is_type_list_of_str(self) -> bool:
@@ -986,6 +990,7 @@ class runopts:
986990

987991
def __init__(self) -> None:
988992
self._opts: Dict[str, runopt] = {}
993+
self._alias_to_key: dict[runopt.alias, str] = {}
989994

990995
def __iter__(self) -> Iterator[Tuple[str, runopt]]:
991996
return self._opts.items().__iter__()
@@ -1013,9 +1018,16 @@ def is_type(obj: CfgVal, tp: Type[CfgVal]) -> bool:
10131018

10141019
def get(self, name: str) -> Optional[runopt]:
10151020
"""
1016-
Returns option if any was registered, or None otherwise
1021+
Returns option if any was registered, or None otherwise.
1022+
First searches for the option by ``name``, then falls-back to matching ``name`` with any
1023+
registered aliases.
1024+
10171025
"""
1018-
return self._opts.get(name, None)
1026+
if name in self._opts:
1027+
return self._opts[name]
1028+
if name in self._alias_to_key:
1029+
return self._opts[self._alias_to_key[name]]
1030+
return None
10191031

10201032
def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
10211033
"""
@@ -1030,6 +1042,24 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
10301042

10311043
for cfg_key, runopt in self._opts.items():
10321044
val = resolved_cfg.get(cfg_key)
1045+
resolved_name = None
1046+
aliases = runopt.aliases or []
1047+
if val is None:
1048+
for alias in aliases:
1049+
val = resolved_cfg.get(alias)
1050+
if alias in cfg or val is not None:
1051+
resolved_name = alias
1052+
break
1053+
else:
1054+
resolved_name = cfg_key
1055+
for alias in aliases:
1056+
duplicate_val = resolved_cfg.get(alias)
1057+
if alias in cfg or duplicate_val is not None:
1058+
raise InvalidRunConfigException(
1059+
f"Duplicate opt name. runopt: `{resolved_name}``, is an alias of runopt: `{alias}`",
1060+
resolved_name,
1061+
cfg,
1062+
)
10331063

10341064
# check required opt
10351065
if runopt.is_required and val is None:
@@ -1049,7 +1079,7 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
10491079
)
10501080

10511081
# not required and not set, set to default
1052-
if val is None:
1082+
if val is None and resolved_name is None:
10531083
resolved_cfg[cfg_key] = runopt.default
10541084
return resolved_cfg
10551085

@@ -1142,9 +1172,38 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
11421172
cfg[key] = val
11431173
return cfg
11441174

1175+
def _get_primary_key_and_aliases(
1176+
self,
1177+
cfg_key: list[str] | str,
1178+
) -> tuple[str, list[runopt.alias]]:
1179+
"""
1180+
Returns the primary key and aliases for the given cfg_key.
1181+
"""
1182+
if isinstance(cfg_key, str):
1183+
return cfg_key, []
1184+
1185+
if len(cfg_key) == 0:
1186+
raise ValueError("cfg_key must be a non-empty list")
1187+
primary_key = None
1188+
aliases = list[runopt.alias]()
1189+
for name in cfg_key:
1190+
if isinstance(name, runopt.alias):
1191+
aliases.append(name)
1192+
else:
1193+
if primary_key is not None:
1194+
raise ValueError(
1195+
f" Given more than one primary key: {primary_key}, {name}. Please use runopt.alias type for aliases. "
1196+
)
1197+
primary_key = name
1198+
if primary_key is None or primary_key == "":
1199+
raise ValueError(
1200+
"Missing cfg_key. Please provide one other than the aliases."
1201+
)
1202+
return primary_key, aliases
1203+
11451204
def add(
11461205
self,
1147-
cfg_key: str,
1206+
cfg_key: str | list[str],
11481207
type_: Type[CfgVal],
11491208
help: str,
11501209
default: CfgVal = None,
@@ -1155,6 +1214,7 @@ def add(
11551214
value (if any). If the ``default`` is not specified then this option
11561215
is a required option.
11571216
"""
1217+
primary_key, aliases = self._get_primary_key_and_aliases(cfg_key)
11581218
if required and default is not None:
11591219
raise ValueError(
11601220
f"Required option: {cfg_key} must not specify default value. Given: {default}"
@@ -1165,8 +1225,10 @@ def add(
11651225
f"Option: {cfg_key}, must be of type: {type_}."
11661226
f" Given: {default} ({type(default).__name__})"
11671227
)
1168-
1169-
self._opts[cfg_key] = runopt(default, type_, required, help)
1228+
opt = runopt(default, type_, required, help, aliases)
1229+
for alias in aliases:
1230+
self._alias_to_key[alias] = primary_key
1231+
self._opts[primary_key] = opt
11701232

11711233
def update(self, other: "runopts") -> None:
11721234
self._opts.update(other._opts)

torchx/specs/test/api_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,49 @@ def test_runopts_add(self) -> None:
578578
# this print is intentional (demonstrates the intended usecase)
579579
print(opts)
580580

581+
def test_runopts_add_with_aliases(self) -> None:
582+
opts = runopts()
583+
opts.add(
584+
["job_priority", runopt.alias("jobPriority")],
585+
type_=str,
586+
help="priority for the job",
587+
)
588+
self.assertEqual(1, len(opts._opts))
589+
self.assertIsNotNone(opts.get("job_priority"))
590+
self.assertIsNotNone(opts.get("jobPriority"))
591+
592+
def test_runopts_resolve_with_aliases(self) -> None:
593+
opts = runopts()
594+
opts.add(
595+
["job_priority", runopt.alias("jobPriority")],
596+
type_=str,
597+
help="priority for the job",
598+
)
599+
opts.resolve({"job_priority": "high"})
600+
opts.resolve({"jobPriority": "low"})
601+
with self.assertRaises(InvalidRunConfigException):
602+
opts.resolve({"job_priority": "high", "jobPriority": "low"})
603+
604+
def test_runopts_resolve_with_none_valued_aliases(self) -> None:
605+
opts = runopts()
606+
opts.add(
607+
["job_priority", runopt.alias("jobPriority")],
608+
type_=str,
609+
help="priority for the job",
610+
)
611+
opts.add(
612+
["modelTypeName", runopt.alias("model_type_name")],
613+
type_=Union[str, None],
614+
help="ML Hub Model Type to attribute resource utilization for job",
615+
)
616+
resolved_opts = opts.resolve({"model_type_name": None, "jobPriority": "low"})
617+
self.assertEqual(resolved_opts.get("model_type_name"), None)
618+
self.assertEqual(resolved_opts.get("jobPriority"), "low")
619+
self.assertEqual(resolved_opts, {"model_type_name": None, "jobPriority": "low"})
620+
621+
with self.assertRaises(InvalidRunConfigException):
622+
opts.resolve({"model_type_name": None, "modelTypeName": "low"})
623+
581624
def get_runopts(self) -> runopts:
582625
opts = runopts()
583626
opts.add("run_as", type_=str, help="run as user", required=True)

0 commit comments

Comments
 (0)