Skip to content

Commit 824675b

Browse files
authored
Update instance type regex to also include hyphens (#5308)
1 parent 4c8814b commit 824675b

File tree

8 files changed

+36
-7
lines changed

8 files changed

+36
-7
lines changed

src/sagemaker/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2119,7 +2119,7 @@ def _get_instance_type(self):
21192119
instance_type = instance_group.instance_type
21202120
if is_pipeline_variable(instance_type):
21212121
continue
2122-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
2122+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
21232123

21242124
if match:
21252125
family = match[1]

src/sagemaker/fw_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ def validate_distribution_for_instance_type(instance_type, distribution):
962962
"""
963963
err_msg = ""
964964
if isinstance(instance_type, str):
965-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
965+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
966966
if match and match[1].startswith("trn"):
967967
keys = list(distribution.keys())
968968
if len(keys) == 0:
@@ -1083,7 +1083,7 @@ def _is_gpu_instance(instance_type):
10831083
bool: Whether or not the instance_type supports GPU
10841084
"""
10851085
if isinstance(instance_type, str):
1086-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1086+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
10871087
if match:
10881088
if match[1].startswith("p") or match[1].startswith("g"):
10891089
return True
@@ -1102,7 +1102,7 @@ def _is_trainium_instance(instance_type):
11021102
bool: Whether or not the instance_type is a Trainium instance
11031103
"""
11041104
if isinstance(instance_type, str):
1105-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1105+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
11061106
if match and match[1].startswith("trn"):
11071107
return True
11081108
return False
@@ -1149,7 +1149,7 @@ def _instance_type_supports_profiler(instance_type):
11491149
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
11501150
"""
11511151
if isinstance(instance_type, str):
1152-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1152+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
11531153
if match and match[1].startswith("trn"):
11541154
return True
11551155
return False

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool:
3838
bool: Whether the given instance type is Inferentia or Trainium.
3939
"""
4040
if isinstance(instance_type, str):
41-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
41+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
4242
if match:
4343
if match[1].startswith("inf") or match[1].startswith("trn"):
4444
return True

src/sagemaker/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1529,7 +1529,7 @@ def get_instance_type_family(instance_type: str) -> str:
15291529
"""
15301530
instance_type_family = ""
15311531
if isinstance(instance_type, str):
1532-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1532+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
15331533
if match is not None:
15341534
instance_type_family = match[1]
15351535
return instance_type_family

tests/unit/sagemaker/serve/utils/test_optimize_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@
9595
[
9696
("ml.trn1.2xlarge", True),
9797
("ml.inf2.xlarge", True),
98+
("ml.trn1-n.2xlarge", True),
99+
("ml.inf2-b.xlarge", True),
98100
("ml.c7gd.4xlarge", False),
99101
],
100102
)

tests/unit/test_estimator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2246,6 +2246,21 @@ def test_get_instance_type_gpu(sagemaker_session):
22462246
assert "ml.p3.16xlarge" == estimator._get_instance_type()
22472247

22482248

2249+
def test_get_instance_type_gpu_with_hyphens(sagemaker_session):
2250+
estimator = Estimator(
2251+
image_uri="some-image",
2252+
role="some_image",
2253+
instance_groups=[
2254+
InstanceGroup("group1", "ml.c4.xlarge", 1),
2255+
InstanceGroup("group2", "ml.p6-b200.48xlarge", 2),
2256+
],
2257+
sagemaker_session=sagemaker_session,
2258+
base_job_name="base_job_name",
2259+
)
2260+
2261+
assert "ml.p6-b200.48xlarge" == estimator._get_instance_type()
2262+
2263+
22492264
def test_estimator_with_output_compression_disabled(sagemaker_session):
22502265
estimator = Estimator(
22512266
image_uri="some-image",

tests/unit/test_fw_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,13 @@ def test_validate_unsupported_distributions_trainium_raises():
10651065
instance_type="ml.trn1.32xlarge",
10661066
)
10671067

1068+
with pytest.raises(ValueError):
1069+
mpi_enabled = {"mpi": {"enabled": True}}
1070+
fw_utils.validate_distribution_for_instance_type(
1071+
distribution=mpi_enabled,
1072+
instance_type="ml.trn1-n.2xlarge",
1073+
)
1074+
10681075
with pytest.raises(ValueError):
10691076
pytorch_ddp_enabled = {"pytorch_ddp": {"enabled": True}}
10701077
fw_utils.validate_distribution_for_instance_type(
@@ -1082,6 +1089,7 @@ def test_validate_unsupported_distributions_trainium_raises():
10821089

10831090
def test_instance_type_supports_profiler():
10841091
assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True
1092+
assert fw_utils._instance_type_supports_profiler("ml.trn1-n.xlarge") is True
10851093
assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False
10861094
assert fw_utils._instance_type_supports_profiler("local") is False
10871095

@@ -1097,6 +1105,8 @@ def test_is_gpu_instance():
10971105
"ml.g4dn.xlarge",
10981106
"ml.g5.xlarge",
10991107
"ml.g5.48xlarge",
1108+
"ml.p6-b200.48xlarge",
1109+
"ml.g6e-12xlarge.xlarge",
11001110
"local_gpu",
11011111
]
11021112
non_gpu_instance_types = [
@@ -1116,6 +1126,7 @@ def test_is_trainium_instance():
11161126
trainium_instance_types = [
11171127
"ml.trn1.2xlarge",
11181128
"ml.trn1.32xlarge",
1129+
"ml.trn1-n.2xlarge",
11191130
]
11201131
non_trainum_instance_types = [
11211132
"ml.t3.xlarge",

tests/unit/test_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1844,6 +1844,7 @@ def test_instance_family_from_full_instance_type(self):
18441844
"ml.afbsadjfbasfb.sdkjfnsa": "afbsadjfbasfb",
18451845
"ml_fdsfsdf.xlarge": "fdsfsdf",
18461846
"ml_c2.4xlarge": "c2",
1847+
"ml.p6-b200.48xlarge": "p6-b200",
18471848
"sdfasfdda": "",
18481849
"local": "",
18491850
"c2.xlarge": "",

0 commit comments

Comments
 (0)