Skip to content

Commit 11b9e99

Browse files
authored
Use auto aliasing for cases for the runopt (meta-pytorch#1143)
* Allow aliases for run_opt (meta-pytorch#1141) Summary: Lets allow aliases for a runopt. This will give downstream users to have multiple ways of accessing the same runopt. * Introduce new class for `runopt.alias` which is used to expand on adding aliases to a runopt. * Add a new dict to maintain alias to key values that can be used by `opt.get(name)` * Modify add() to accept list as well, build out the aliases list and modify the previously created dict to fill in alias to primary_key values. * Modify resolve() to check if a different alias is already used in cfg i.e if the "jobPriority" and "job_priority" are aliases for the same one, we don't allow for both to be present in the cfg. * Modify get to look at the alias to primary_key dict as well. Reviewed By: kiukchung Differential Revision: D84157870 * Add deprecated_aliases to runopt and add warning (meta-pytorch#1142) Summary: Similar to `runopt.alias` lets introduce and use `runopt.deprecated`. This will warn the user with a`UserWarning` when the user uses that specific name and suggests the primary one instead. Reviewed By: kiukchung Differential Revision: D84180061 * Use auto aliasing for cases for the runopt (meta-pytorch#1143) Summary: Introduces enum for auto aliasing based on casing for the runopt nam ``` class AutoAlias(IntEnum): snake_case = 0x1 SNAKE_CASE = 0x2 camelCase = 0x4 ``` So user can extend name to be used as ``` opts.add( ["job_priority", runopt.AutoAlias.camelCase], type_=str, help="run as user", ) opts.add( [ "model_type_name", runopt.AutoAlias.camelCase | runopt.AutoAlias.SNAKE_CASE, ], type_=str, help="run as user", ) ``` This should automatically produce additional aliases of `jobPriority` to `job_priority` and produce `modelTypeName` and `MODEL_TYPE_NAME` for `model_type_name` Reviewed By: kiukchung Differential Revision: D84192560
1 parent 1d26b39 commit 11b9e99

File tree

2 files changed

+79
-11
lines changed

2 files changed

+79
-11
lines changed

torchx/specs/api.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import warnings
1919
from dataclasses import asdict, dataclass, field
2020
from datetime import datetime
21-
from enum import Enum
21+
from enum import Enum, IntEnum
2222
from json import JSONDecodeError
2323
from string import Template
2424
from typing import (
@@ -892,6 +892,30 @@ class runopt:
892892
Represents the metadata about the specific run option
893893
"""
894894

895+
class AutoAlias(IntEnum):
896+
snake_case = 0x1
897+
SNAKE_CASE = 0x2
898+
camelCase = 0x4
899+
900+
@staticmethod
901+
def convert_to_camel_case(alias: str) -> str:
902+
words = re.split(r"[_\-\s]+|(?<=[a-z])(?=[A-Z])", alias)
903+
words = [w for w in words if w] # Remove empty strings
904+
if not words:
905+
return ""
906+
return words[0].lower() + "".join(w.capitalize() for w in words[1:])
907+
908+
@staticmethod
909+
def convert_to_snake_case(alias: str) -> str:
910+
alias = re.sub(r"[-\s]+", "_", alias)
911+
alias = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", alias)
912+
alias = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", alias)
913+
return alias.lower()
914+
915+
@staticmethod
916+
def convert_to_const_case(alias: str) -> str:
917+
return runopt.AutoAlias.convert_to_snake_case(alias).upper()
918+
895919
class alias(str):
896920
pass
897921

@@ -902,8 +926,8 @@ class deprecated(str):
902926
opt_type: Type[CfgVal]
903927
is_required: bool
904928
help: str
905-
aliases: list[alias] | None = None
906-
deprecated_aliases: list[deprecated] | None = None
929+
aliases: set[alias] | None = None
930+
deprecated_aliases: set[deprecated] | None = None
907931

908932
@property
909933
def is_type_list_of_str(self) -> bool:
@@ -1189,15 +1213,28 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
11891213
cfg[key] = val
11901214
return cfg
11911215

1216+
def _generate_aliases(
1217+
self, auto_alias: int, aliases: set[str]
1218+
) -> set[runopt.alias]:
1219+
generated_aliases = set()
1220+
for alias in aliases:
1221+
if auto_alias & runopt.AutoAlias.camelCase:
1222+
generated_aliases.add(runopt.AutoAlias.convert_to_camel_case(alias))
1223+
if auto_alias & runopt.AutoAlias.snake_case:
1224+
generated_aliases.add(runopt.AutoAlias.convert_to_snake_case(alias))
1225+
if auto_alias & runopt.AutoAlias.SNAKE_CASE:
1226+
generated_aliases.add(runopt.AutoAlias.convert_to_const_case(alias))
1227+
return generated_aliases
1228+
11921229
def _get_primary_key_and_aliases(
11931230
self,
1194-
cfg_key: list[str] | str,
1195-
) -> tuple[str, list[runopt.alias], list[runopt.deprecated]]:
1231+
cfg_key: list[str | int] | str,
1232+
) -> tuple[str, set[runopt.alias], set[runopt.deprecated]]:
11961233
"""
11971234
Returns the primary key and aliases for the given cfg_key.
11981235
"""
11991236
if isinstance(cfg_key, str):
1200-
return cfg_key, [], []
1237+
return cfg_key, set(), set()
12011238

12021239
if len(cfg_key) == 0:
12031240
raise ValueError("cfg_key must be a non-empty list")
@@ -1211,13 +1248,16 @@ def _get_primary_key_and_aliases(
12111248
stacklevel=2,
12121249
)
12131250
primary_key = None
1214-
aliases = list[runopt.alias]()
1215-
deprecated_aliases = list[runopt.deprecated]()
1251+
auto_alias = 0x0
1252+
aliases = set[runopt.alias]()
1253+
deprecated_aliases = set[runopt.deprecated]()
12161254
for name in cfg_key:
12171255
if isinstance(name, runopt.alias):
1218-
aliases.append(name)
1256+
aliases.add(name)
12191257
elif isinstance(name, runopt.deprecated):
1220-
deprecated_aliases.append(name)
1258+
deprecated_aliases.add(name)
1259+
elif isinstance(name, int):
1260+
auto_alias = auto_alias | name
12211261
else:
12221262
if primary_key is not None:
12231263
raise ValueError(
@@ -1228,11 +1268,17 @@ def _get_primary_key_and_aliases(
12281268
raise ValueError(
12291269
"Missing cfg_key. Please provide one other than the aliases."
12301270
)
1271+
if auto_alias != 0x0:
1272+
aliases_to_generate_for = aliases | {primary_key}
1273+
additional_aliases = self._generate_aliases(
1274+
auto_alias, aliases_to_generate_for
1275+
)
1276+
aliases.update(additional_aliases)
12311277
return primary_key, aliases, deprecated_aliases
12321278

12331279
def add(
12341280
self,
1235-
cfg_key: str | list[str],
1281+
cfg_key: str | list[str | int],
12361282
type_: Type[CfgVal],
12371283
help: str,
12381284
default: CfgVal = None,

torchx/specs/test/api_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,28 @@ def test_runopts_add_with_deprecated_aliases(self) -> None:
648648
"Run option `jobPriority` is deprecated, use `job_priority` instead",
649649
)
650650

651+
def test_runopt_auto_aliases(self) -> None:
652+
opts = runopts()
653+
opts.add(
654+
["job_priority", runopt.AutoAlias.camelCase],
655+
type_=str,
656+
help="run as user",
657+
)
658+
opts.add(
659+
[
660+
"model_type_name",
661+
runopt.AutoAlias.camelCase | runopt.AutoAlias.SNAKE_CASE,
662+
],
663+
type_=str,
664+
help="run as user",
665+
)
666+
self.assertEqual(2, len(opts._opts))
667+
self.assertIsNotNone(opts.get("job_priority"))
668+
self.assertIsNotNone(opts.get("jobPriority"))
669+
self.assertIsNotNone(opts.get("model_type_name"))
670+
self.assertIsNotNone(opts.get("modelTypeName"))
671+
self.assertIsNotNone(opts.get("MODEL_TYPE_NAME"))
672+
651673
def get_runopts(self) -> runopts:
652674
opts = runopts()
653675
opts.add("run_as", type_=str, help="run as user", required=True)

0 commit comments

Comments
 (0)