From 98be78cdb582e078bfc1667cd89b5b27fb7aa83f Mon Sep 17 00:00:00 2001 From: Gokul A Date: Mon, 13 Oct 2025 15:50:22 -0700 Subject: [PATCH] Update instance type regex to also include hyphens --- src/sagemaker/estimator.py | 2 +- src/sagemaker/fw_utils.py | 8 ++++---- src/sagemaker/serve/utils/optimize_utils.py | 2 +- src/sagemaker/utils.py | 2 +- .../sagemaker/serve/utils/test_optimize_utils.py | 2 ++ tests/unit/test_estimator.py | 15 +++++++++++++++ tests/unit/test_fw_utils.py | 11 +++++++++++ tests/unit/test_utils.py | 1 + 8 files changed, 36 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 8cd6410ea0..2d8318fd39 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2119,7 +2119,7 @@ def _get_instance_type(self): instance_type = instance_group.instance_type if is_pipeline_variable(instance_type): continue - match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type) if match: family = match[1] diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 4a00b2dbc1..42e55eede8 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -962,7 +962,7 @@ def validate_distribution_for_instance_type(instance_type, distribution): """ err_msg = "" if isinstance(instance_type, str): - match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type) if match and match[1].startswith("trn"): keys = list(distribution.keys()) if len(keys) == 0: @@ -1083,7 +1083,7 @@ def _is_gpu_instance(instance_type): bool: Whether or not the instance_type supports GPU """ if isinstance(instance_type, str): - match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type) if match: if match[1].startswith("p") or match[1].startswith("g"): return True @@ -1102,7 +1102,7 @@ def _is_trainium_instance(instance_type): bool: Whether or not the instance_type is a Trainium instance """ if isinstance(instance_type, str): - match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type) if match and match[1].startswith("trn"): return True return False @@ -1149,7 +1149,7 @@ def _instance_type_supports_profiler(instance_type): bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature. """ if isinstance(instance_type, str): - match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type) if match and match[1].startswith("trn"): return True return False diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 68ed1e846d..7b36f0cf87 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -38,7 +38,7 @@ def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool: bool: Whether the given instance type is Inferentia or Trainium. """ if isinstance(instance_type, str): - match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type) if match: if match[1].startswith("inf") or match[1].startswith("trn"): return True diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index af3cc16f1e..33744bd455 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1529,7 +1529,7 @@ def get_instance_type_family(instance_type: str) -> str: """ instance_type_family = "" if isinstance(instance_type, str): - match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type) if match is not None: instance_type_family = match[1] return instance_type_family diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index b392b255da..184393d6f1 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -95,6 +95,8 @@ [ ("ml.trn1.2xlarge", True), ("ml.inf2.xlarge", True), + ("ml.trn1-n.2xlarge", True), + ("ml.inf2-b.xlarge", True), ("ml.c7gd.4xlarge", False), ], ) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 1698da3e90..c953b2ffd5 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -2246,6 +2246,21 @@ def test_get_instance_type_gpu(sagemaker_session): assert "ml.p3.16xlarge" == estimator._get_instance_type() +def test_get_instance_type_gpu_with_hyphens(sagemaker_session): + estimator = Estimator( + image_uri="some-image", + role="some_image", + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.p6-b200.48xlarge", 2), + ], + sagemaker_session=sagemaker_session, + base_job_name="base_job_name", + ) + + assert "ml.p6-b200.48xlarge" == estimator._get_instance_type() + + def test_estimator_with_output_compression_disabled(sagemaker_session): estimator = Estimator( image_uri="some-image", diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 97d4e6ec2a..065630f500 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -1065,6 +1065,13 @@ def test_validate_unsupported_distributions_trainium_raises(): instance_type="ml.trn1.32xlarge", ) + with pytest.raises(ValueError): + mpi_enabled = {"mpi": {"enabled": True}} + fw_utils.validate_distribution_for_instance_type( + distribution=mpi_enabled, + instance_type="ml.trn1-n.2xlarge", + ) + with pytest.raises(ValueError): pytorch_ddp_enabled = {"pytorch_ddp": {"enabled": True}} fw_utils.validate_distribution_for_instance_type( @@ -1082,6 +1089,7 @@ def test_validate_unsupported_distributions_trainium_raises(): def test_instance_type_supports_profiler(): assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True + assert fw_utils._instance_type_supports_profiler("ml.trn1-n.xlarge") is True assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False assert fw_utils._instance_type_supports_profiler("local") is False @@ -1097,6 +1105,8 @@ def test_is_gpu_instance(): "ml.g4dn.xlarge", "ml.g5.xlarge", "ml.g5.48xlarge", + "ml.p6-b200.48xlarge", + "ml.g6e-12xlarge.xlarge", "local_gpu", ] non_gpu_instance_types = [ @@ -1116,6 +1126,7 @@ def test_is_trainium_instance(): trainium_instance_types = [ "ml.trn1.2xlarge", "ml.trn1.32xlarge", + "ml.trn1-n.2xlarge", ] non_trainum_instance_types = [ "ml.t3.xlarge", diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index f243bf1635..5deff5163b 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1844,6 +1844,7 @@ def test_instance_family_from_full_instance_type(self): "ml.afbsadjfbasfb.sdkjfnsa": "afbsadjfbasfb", "ml_fdsfsdf.xlarge": "fdsfsdf", "ml_c2.4xlarge": "c2", + "ml.p6-b200.48xlarge": "p6-b200", "sdfasfdda": "", "local": "", "c2.xlarge": "",