3030]
3131
3232
33- def _test_graviton_framework_uris (framework , version , py_version , account , region ):
33+ def _test_graviton_framework_uris (
34+ framework , version , py_version , account , region , container_version = "ubuntu20.04-sagemaker"
35+ ):
3436 for instance_type in GRAVITON_INSTANCE_TYPES :
3537 uri = image_uris .retrieve (framework , region , instance_type = instance_type , version = version )
3638 expected = _expected_graviton_framework_uri (
37- framework , version , py_version , account , region = region
39+ framework ,
40+ version ,
41+ py_version ,
42+ account ,
43+ region = region ,
44+ container_version = container_version ,
3845 )
3946 assert expected == uri
4047
@@ -50,11 +57,21 @@ def test_graviton_framework_uris(load_config_and_file_name, scope):
5057 for version in VERSIONS :
5158 ACCOUNTS = config [scope ]["versions" ][version ]["registries" ]
5259 py_versions = config [scope ]["versions" ][version ]["py_versions" ]
60+ container_version = (
61+ config [scope ]["versions" ][version ].get ("container_version" , {}).get ("cpu" , None )
62+ )
63+ if container_version :
64+ container_version = container_version + "-sagemaker"
5365 for py_version in py_versions :
5466 for region in ACCOUNTS .keys ():
55- _test_graviton_framework_uris (
56- framework , version , py_version , ACCOUNTS [region ], region
57- )
67+ if container_version :
68+ _test_graviton_framework_uris (
69+ framework , version , py_version , ACCOUNTS [region ], region , container_version
70+ )
71+ else :
72+ _test_graviton_framework_uris (
73+ framework , version , py_version , ACCOUNTS [region ], region
74+ )
5875
5976
6077def _test_graviton_unsupported_framework (framework , region , framework_version ):
@@ -183,11 +200,14 @@ def test_graviton_sklearn_image_scope_specified_x86_instance(graviton_sklearn_un
183200 assert "Unsupported instance type: m5." in str (error )
184201
185202
186- def _expected_graviton_framework_uri (framework , version , py_version , account , region ):
203+ def _expected_graviton_framework_uri (
204+ framework , version , py_version , account , region , container_version
205+ ):
187206 return expected_uris .graviton_framework_uri (
188207 "{}-inference-graviton" .format (framework ),
189208 fw_version = version ,
190209 py_version = py_version ,
191210 account = account ,
192211 region = region ,
212+ container_version = container_version ,
193213 )
0 commit comments