Skip to content

Commit c7cdf3b

Browse files
author
Roja Reddy Sareddy
committed
Update image name tests to address new region
1 parent 0976e74 commit c7cdf3b

File tree

4 files changed

+20
-0
lines changed

4 files changed

+20
-0
lines changed

tests/unit/sagemaker/image_uris/expected_uris.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,10 @@ def sagemaker_distribution_uri(repo, account, tag, processor, region=REGION):
116116
else:
117117
tag = f"{tag}-gpu"
118118
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
119+
120+
def get_special_region_domain(region):
121+
SPECIAL_REGIONS = {
122+
"eu-isoe-west-1": ".cloud.adc-e.uk",
123+
"eusc-de-east-1": ".amazonaws.eu"
124+
}
125+
return SPECIAL_REGIONS.get(region, ".amazonaws.com")

tests/unit/sagemaker/image_uris/test_dlc_frameworks.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ def test_dlc_framework_uris(load_config_and_file_name, scope):
107107
region=region,
108108
account=ACCOUNTS[region],
109109
)
110+
# Handle special regions
111+
domain = expected_uris.get_special_region_domain(region)
112+
if domain != ".amazonaws.com":
113+
expected = expected.replace(".amazonaws.com", domain)
114+
110115
assert uri == expected
111116

112117

tests/unit/sagemaker/image_uris/test_graviton.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def _test_graviton_framework_uris(
4343
region=region,
4444
container_version=container_version,
4545
)
46+
# Handle special regions
47+
domain = expected_uris.get_special_region_domain(region)
48+
if domain != ".amazonaws.com":
49+
expected = expected.replace(".amazonaws.com", domain)
4650
assert expected == uri
4751

4852

tests/unit/sagemaker/image_uris/test_huggingface_llm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ def test_huggingface_uris(load_config):
111111
HF_VERSIONS_MAPPING[device][version],
112112
region=region,
113113
)
114+
# Handle special regions
115+
domain = expected_uris.get_special_region_domain(region)
116+
if domain != ".amazonaws.com":
117+
expected = expected.replace(".amazonaws.com", domain)
114118
assert expected == uri
115119

116120

0 commit comments

Comments
 (0)