Skip to content

Commit 3ac3a1e

Browse files
AbishekSfacebook-github-bot
authored andcommitted
Use auto aliasing for cases for the runopt
Summary: Introduces enum for auto aliasing based on casing for the runopt nam ``` class AutoAlias(IntEnum): snake_case = 0x1 SNAKE_CASE = 0x2 camelCase = 0x4 ``` So user can extend name to be used as ``` opts.add( ["job_priority", runopt.AutoAlias.camelCase], type_=str, help="run as user", ) opts.add( [ "model_type_name", runopt.AutoAlias.camelCase | runopt.AutoAlias.SNAKE_CASE, ], type_=str, help="run as user", ) ``` This should automatically produce additional aliases of `jobPriority` to `job_priority` and produce `modelTypeName` and `MODEL_TYPE_NAME` for `model_type_name` Differential Revision: D84192560
1 parent 695fd17 commit 3ac3a1e

File tree

2 files changed

+62
-11
lines changed

2 files changed

+62
-11
lines changed

torchx/specs/api.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import warnings
1919
from dataclasses import asdict, dataclass, field
2020
from datetime import datetime
21-
from enum import Enum
21+
from enum import Enum, IntEnum
2222
from json import JSONDecodeError
2323
from string import Template
2424
from typing import (
@@ -38,6 +38,8 @@
3838
Union,
3939
)
4040

41+
import stringcase
42+
4143
from torchx.util.types import to_dict
4244

4345
_APP_STATUS_FORMAT_TEMPLATE = """AppStatus:
@@ -893,6 +895,11 @@ class runopt:
893895
Represents the metadata about the specific run option
894896
"""
895897

898+
class AutoAlias(IntEnum):
899+
snake_case = 0x1
900+
SNAKE_CASE = 0x2
901+
camelCase = 0x4
902+
896903
class alias(str):
897904
pass
898905

@@ -903,8 +910,8 @@ class deprecated(str):
903910
opt_type: Type[CfgVal]
904911
is_required: bool
905912
help: str
906-
aliases: list[alias] | None = None
907-
deprecated_aliases: list[deprecated] | None = None
913+
aliases: set[alias] | None = None
914+
deprecated_aliases: set[deprecated] | None = None
908915

909916
@property
910917
def is_type_list_of_str(self) -> bool:
@@ -1190,26 +1197,42 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
11901197
cfg[key] = val
11911198
return cfg
11921199

1200+
def _generate_aliases(
1201+
self, auto_alias: int, aliases: set[str]
1202+
) -> set[runopt.alias]:
1203+
generated_aliases = set()
1204+
for alias in aliases:
1205+
if auto_alias & runopt.AutoAlias.camelCase:
1206+
generated_aliases.add(stringcase.camelcase(alias))
1207+
if auto_alias & runopt.AutoAlias.snake_case:
1208+
generated_aliases.add(stringcase.snakecase(alias))
1209+
if auto_alias & runopt.AutoAlias.SNAKE_CASE:
1210+
generated_aliases.add(stringcase.constcase(alias))
1211+
return generated_aliases
1212+
11931213
def _get_primary_key_and_aliases(
11941214
self,
1195-
cfg_key: list[str] | str,
1196-
) -> tuple[str, list[runopt.alias], list[runopt.deprecated]]:
1215+
cfg_key: list[str | int] | str,
1216+
) -> tuple[str, set[runopt.alias], set[runopt.deprecated]]:
11971217
"""
11981218
Returns the primary key and aliases for the given cfg_key.
11991219
"""
12001220
if isinstance(cfg_key, str):
1201-
return cfg_key, [], []
1221+
return cfg_key, set(), set()
12021222

12031223
if len(cfg_key) == 0:
12041224
raise ValueError("cfg_key must be a non-empty list")
12051225
primary_key = None
1206-
aliases = list[runopt.alias]()
1207-
deprecated_aliases = list[runopt.deprecated]()
1226+
auto_alias = 0x0
1227+
aliases = set[runopt.alias]()
1228+
deprecated_aliases = set[runopt.deprecated]()
12081229
for name in cfg_key:
12091230
if isinstance(name, runopt.alias):
1210-
aliases.append(name)
1231+
aliases.add(name)
12111232
elif isinstance(name, runopt.deprecated):
1212-
deprecated_aliases.append(name)
1233+
deprecated_aliases.add(name)
1234+
elif isinstance(name, int):
1235+
auto_alias = auto_alias | name
12131236
else:
12141237
if primary_key is not None:
12151238
raise ValueError(
@@ -1220,11 +1243,17 @@ def _get_primary_key_and_aliases(
12201243
raise ValueError(
12211244
"Missing cfg_key. Please provide one other than the aliases."
12221245
)
1246+
if auto_alias != 0x0:
1247+
aliases_to_generate_for = aliases | {primary_key}
1248+
additional_aliases = self._generate_aliases(
1249+
auto_alias, aliases_to_generate_for
1250+
)
1251+
aliases.update(additional_aliases)
12231252
return primary_key, aliases, deprecated_aliases
12241253

12251254
def add(
12261255
self,
1227-
cfg_key: str | list[str],
1256+
cfg_key: str | list[str | int],
12281257
type_: Type[CfgVal],
12291258
help: str,
12301259
default: CfgVal = None,

torchx/specs/test/api_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,28 @@ def test_runopts_add_with_deprecated_aliases(self) -> None:
640640
"Run option: jobPriority, is deprecated. Please use job_priority instead",
641641
)
642642

643+
def test_runopt_auto_aliases(self) -> None:
644+
opts = runopts()
645+
opts.add(
646+
["job_priority", runopt.AutoAlias.camelCase],
647+
type_=str,
648+
help="run as user",
649+
)
650+
opts.add(
651+
[
652+
"model_type_name",
653+
runopt.AutoAlias.camelCase | runopt.AutoAlias.SNAKE_CASE,
654+
],
655+
type_=str,
656+
help="run as user",
657+
)
658+
self.assertEqual(2, len(opts._opts))
659+
self.assertIsNotNone(opts.get("job_priority"))
660+
self.assertIsNotNone(opts.get("jobPriority"))
661+
self.assertIsNotNone(opts.get("model_type_name"))
662+
self.assertIsNotNone(opts.get("modelTypeName"))
663+
self.assertIsNotNone(opts.get("MODEL_TYPE_NAME"))
664+
643665
def get_runopts(self) -> runopts:
644666
opts = runopts()
645667
opts.add("run_as", type_=str, help="run as user", required=True)

0 commit comments

Comments
 (0)