diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 45509f65f6..cc42896cf5 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1457,11 +1457,11 @@ def volume_size_supported(instance_type: str) -> bool: # Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc) # + g5 or g6 or p5 does not support attaching an EBS volume. family = parts[0] - return ( - "d" not in family - and not family.startswith("g5") - and not family.startswith("g6") - and not family.startswith("p5") + + unsupported_families = ["g5", "g6", "p5", "trn1"] + + return "d" not in family and not any( + family.startswith(prefix) for prefix in unsupported_families ) except Exception as e: raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}") diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 2f0fd77580..3284d966e2 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1815,6 +1815,8 @@ def test_volume_size_not_supported(self): "local", "local_gpu", ParameterString(name="InstanceType", default_value="ml.m4.xlarge"), + "ml.trn1.32xlarge", + "ml.trn1n.32xlarge", ] for instance in instances_that_dont_support_volume_size: