@@ -560,6 +560,22 @@ def test_compile_model_for_edge_device(sagemaker_session, tmpdir):
560560 assert model ._is_compiled_model is False
561561
562562
563+ def test_compile_model_for_edge_device_tflite (sagemaker_session , tmpdir ):
564+ sagemaker_session .wait_for_compilation_job = Mock (
565+ return_value = DESCRIBE_COMPILATION_JOB_RESPONSE
566+ )
567+ model = DummyFrameworkModel (sagemaker_session , source_dir = str (tmpdir ))
568+ model .compile (
569+ target_instance_family = "deeplens" ,
570+ input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
571+ output_path = "s3://output" ,
572+ role = "role" ,
573+ framework = "tflite" ,
574+ job_name = "tflite-compile-model" ,
575+ )
576+ assert model ._is_compiled_model is False
577+
578+
563579def test_compile_model_for_cloud (sagemaker_session , tmpdir ):
564580 sagemaker_session .wait_for_compilation_job = Mock (
565581 return_value = DESCRIBE_COMPILATION_JOB_RESPONSE
@@ -576,6 +592,22 @@ def test_compile_model_for_cloud(sagemaker_session, tmpdir):
576592 assert model ._is_compiled_model is True
577593
578594
595+ def test_compile_model_for_cloud_tflite (sagemaker_session , tmpdir ):
596+ sagemaker_session .wait_for_compilation_job = Mock (
597+ return_value = DESCRIBE_COMPILATION_JOB_RESPONSE
598+ )
599+ model = DummyFrameworkModel (sagemaker_session , source_dir = str (tmpdir ))
600+ model .compile (
601+ target_instance_family = "ml_c4" ,
602+ input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
603+ output_path = "s3://output" ,
604+ role = "role" ,
605+ framework = "tflite" ,
606+ job_name = "tflite-compile-model" ,
607+ )
608+ assert model ._is_compiled_model is True
609+
610+
579611@patch ("sagemaker.session.Session" )
580612@patch ("sagemaker.fw_utils.tar_and_upload_dir" , MagicMock ())
581613def test_compile_creates_session (session ):
0 commit comments