# Module-level constants shared by the compilation tests in this file.
2020MODEL_DATA = "s3://bucket/model.tar.gz"
2121MODEL_IMAGE = "mi"
2222
# Placeholder container URI; it is the "InferenceImage" value of the canned
# compilation-job response below and is what the tests compare model.image_uri to.
23+ IMAGE_URI = "inference-container-uri"
24+
2325REGION = "us-west-2"
2426
2527NEO_REGION_ACCOUNT = "301217895009"
# Canned response for a completed compilation job. Tests in this file install it
# as the return value of a mocked sagemaker_session.wait_for_compilation_job.
2628DESCRIBE_COMPILATION_JOB_RESPONSE = {
2729 "CompilationJobStatus" : "Completed" ,
2830 "ModelArtifacts" : {"S3ModelArtifacts" : "s3://output-path/model.tar.gz" },
31+ "InferenceImage" : IMAGE_URI ,
2932}
3033
3134
@@ -52,12 +55,7 @@ def test_compile_model_for_inferentia(sagemaker_session):
5255 framework_version = "1.15.0" ,
5356 job_name = "compile-model" ,
5457 )
55- assert (
56- "{}.dkr.ecr.{}.amazonaws.com/sagemaker-neo-tensorflow:1.15.0-inf-py3" .format (
57- NEO_REGION_ACCOUNT , REGION
58- )
59- == model .image_uri
60- )
58+ assert DESCRIBE_COMPILATION_JOB_RESPONSE ["InferenceImage" ] == model .image_uri
6159 assert model ._is_compiled_model is True
6260
6361
@@ -271,11 +269,12 @@ def test_deploy_add_compiled_model_suffix_to_endpoint_name_from_model_name(sagem
271269 assert model .endpoint_name .startswith ("{}-ml-c4" .format (model_name ))
272270
273271
274- @patch ("sagemaker.session.Session" )
275- def test_compile_with_framework_version_15 (session ):
276- session .return_value .boto_region_name = REGION
272+ def test_compile_with_framework_version_15 (sagemaker_session ):
273+ sagemaker_session .wait_for_compilation_job = Mock (
274+ return_value = DESCRIBE_COMPILATION_JOB_RESPONSE
275+ )
277276
278- model = _create_model ()
277+ model = _create_model (sagemaker_session )
279278 model .compile (
280279 target_instance_family = "ml_c4" ,
281280 input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
@@ -286,14 +285,15 @@ def test_compile_with_framework_version_15(session):
286285 job_name = "compile-model" ,
287286 )
288287
289- assert "1.5" in model .image_uri
288+ assert IMAGE_URI == model .image_uri
290289
291290
292- @patch ("sagemaker.session.Session" )
293- def test_compile_with_framework_version_16 (session ):
294- session .return_value .boto_region_name = REGION
291+ def test_compile_with_framework_version_16 (sagemaker_session ):
292+ sagemaker_session .wait_for_compilation_job = Mock (
293+ return_value = DESCRIBE_COMPILATION_JOB_RESPONSE
294+ )
295295
296- model = _create_model ()
296+ model = _create_model (sagemaker_session )
297297 model .compile (
298298 target_instance_family = "ml_c4" ,
299299 input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
@@ -304,26 +304,7 @@ def test_compile_with_framework_version_16(session):
304304 job_name = "compile-model" ,
305305 )
306306
307- assert "1.6" in model .image_uri
308-
309-
# Removed by this change: the previous form of this test patched
# sagemaker.session.Session and expected compile() to raise ValueError with the
# message "Unsupported neo-pytorch version: 1.6.1." for framework_version "1.6.1".
310- @patch ("sagemaker.session.Session" )
311- def test_compile_validates_framework_version (session ):
312- session .return_value .boto_region_name = REGION
313-
314- model = _create_model ()
315- with pytest .raises (ValueError ) as e :
316- model .compile (
317- target_instance_family = "ml_c4" ,
318- input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
319- output_path = "s3://output" ,
320- role = "role" ,
321- framework = "pytorch" ,
322- framework_version = "1.6.1" ,
323- job_name = "compile-model" ,
324- )
325-
326- assert "Unsupported neo-pytorch version: 1.6.1." in str (e )
307+ assert IMAGE_URI == model .image_uri
327308
328309
329310@patch ("sagemaker.session.Session" )
@@ -347,3 +328,25 @@ def test_compile_with_pytorch_neo_in_ml_inf(session):
347328 )
348329 != model .image_uri
349330 )
331+
332+
# Added by this change: compile() is called with pytorch 1.6.1 and is expected to
# complete without raising; because the mocked job response carries no inference
# image, model.image_uri must end up as None.
# NOTE(review): the test name still says "validates", but no validation error is
# asserted here any more — confirm the name is still intended.
333+ def test_compile_validates_framework_version (sagemaker_session ):
# Stub wait_for_compilation_job to return a completed job whose "InferenceImage" is None.
334+ sagemaker_session .wait_for_compilation_job = Mock (
335+ return_value = {
336+ "CompilationJobStatus" : "Completed" ,
337+ "ModelArtifacts" : {"S3ModelArtifacts" : "s3://output-path/model.tar.gz" },
338+ "InferenceImage" : None ,
339+ }
340+ )
341+ model = _create_model (sagemaker_session )
342+ model .compile (
343+ target_instance_family = "ml_c4" ,
344+ input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
345+ output_path = "s3://output" ,
346+ role = "role" ,
347+ framework = "pytorch" ,
348+ framework_version = "1.6.1" ,
349+ job_name = "compile-model" ,
350+ )
351+
# The None from the stubbed response is expected to show up on the model.
352+ assert model .image_uri is None
0 commit comments