|
36 | 36 | ACCELERATOR_TYPE = "ml.eia.medium" |
37 | 37 | IMAGE_NAME = "fakeimage" |
38 | 38 | REGION = "us-west-2" |
39 | | -NEO_REGION_ACCOUNT = "301217895009" |
40 | 39 | MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP) |
41 | 40 | GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" |
42 | 41 | BRANCH = "test-branch-git-config" |
|
50 | 49 | CODECOMMIT_BRANCH = "master" |
51 | 50 | REPO_DIR = "/tmp/repo_dir" |
52 | 51 |
|
53 | | -DESCRIBE_COMPILATION_JOB_RESPONSE = { |
54 | | - "CompilationJobStatus": "Completed", |
55 | | - "ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"}, |
56 | | -} |
57 | | - |
58 | 52 |
|
59 | 53 | class DummyFrameworkModel(FrameworkModel): |
60 | 54 | def __init__(self, sagemaker_session, **kwargs): |
@@ -237,170 +231,6 @@ def test_deploy_update_endpoint_optional_args(sagemaker_session, tmpdir): |
237 | 231 | sagemaker_session.create_endpoint.assert_not_called() |
238 | 232 |
|
239 | 233 |
|
240 | | -def test_compile_model_for_inferentia(sagemaker_session, tmpdir): |
241 | | - sagemaker_session.wait_for_compilation_job = Mock( |
242 | | - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE |
243 | | - ) |
244 | | - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
245 | | - model.compile( |
246 | | - target_instance_family="ml_inf", |
247 | | - input_shape={"data": [1, 3, 1024, 1024]}, |
248 | | - output_path="s3://output", |
249 | | - role="role", |
250 | | - framework="tensorflow", |
251 | | - framework_version="1.15.0", |
252 | | - job_name="compile-model", |
253 | | - ) |
254 | | - assert ( |
255 | | - "{}.dkr.ecr.{}.amazonaws.com/sagemaker-neo-tensorflow:1.15.0-inf-py3".format( |
256 | | - NEO_REGION_ACCOUNT, REGION |
257 | | - ) |
258 | | - == model.image |
259 | | - ) |
260 | | - assert model._is_compiled_model is True |
261 | | - |
262 | | - |
263 | | -def test_compile_model_for_edge_device(sagemaker_session, tmpdir): |
264 | | - sagemaker_session.wait_for_compilation_job = Mock( |
265 | | - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE |
266 | | - ) |
267 | | - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
268 | | - model.compile( |
269 | | - target_instance_family="deeplens", |
270 | | - input_shape={"data": [1, 3, 1024, 1024]}, |
271 | | - output_path="s3://output", |
272 | | - role="role", |
273 | | - framework="tensorflow", |
274 | | - job_name="compile-model", |
275 | | - ) |
276 | | - assert model._is_compiled_model is False |
277 | | - |
278 | | - |
279 | | -def test_compile_model_for_edge_device_tflite(sagemaker_session, tmpdir): |
280 | | - sagemaker_session.wait_for_compilation_job = Mock( |
281 | | - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE |
282 | | - ) |
283 | | - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
284 | | - model.compile( |
285 | | - target_instance_family="deeplens", |
286 | | - input_shape={"data": [1, 3, 1024, 1024]}, |
287 | | - output_path="s3://output", |
288 | | - role="role", |
289 | | - framework="tflite", |
290 | | - job_name="tflite-compile-model", |
291 | | - ) |
292 | | - assert model._is_compiled_model is False |
293 | | - |
294 | | - |
295 | | -def test_compile_model_for_cloud(sagemaker_session, tmpdir): |
296 | | - sagemaker_session.wait_for_compilation_job = Mock( |
297 | | - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE |
298 | | - ) |
299 | | - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
300 | | - model.compile( |
301 | | - target_instance_family="ml_c4", |
302 | | - input_shape={"data": [1, 3, 1024, 1024]}, |
303 | | - output_path="s3://output", |
304 | | - role="role", |
305 | | - framework="tensorflow", |
306 | | - job_name="compile-model", |
307 | | - ) |
308 | | - assert model._is_compiled_model is True |
309 | | - |
310 | | - |
311 | | -def test_compile_model_for_cloud_tflite(sagemaker_session, tmpdir): |
312 | | - sagemaker_session.wait_for_compilation_job = Mock( |
313 | | - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE |
314 | | - ) |
315 | | - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
316 | | - model.compile( |
317 | | - target_instance_family="ml_c4", |
318 | | - input_shape={"data": [1, 3, 1024, 1024]}, |
319 | | - output_path="s3://output", |
320 | | - role="role", |
321 | | - framework="tflite", |
322 | | - job_name="tflite-compile-model", |
323 | | - ) |
324 | | - assert model._is_compiled_model is True |
325 | | - |
326 | | - |
327 | | -@patch("sagemaker.session.Session") |
328 | | -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
329 | | -def test_compile_creates_session(session): |
330 | | - session.return_value.boto_region_name = "us-west-2" |
331 | | - |
332 | | - model = DummyFrameworkModel(sagemaker_session=None) |
333 | | - model.compile( |
334 | | - target_instance_family="ml_c4", |
335 | | - input_shape={"data": [1, 3, 1024, 1024]}, |
336 | | - output_path="s3://output", |
337 | | - role="role", |
338 | | - framework="tensorflow", |
339 | | - job_name="compile-model", |
340 | | - ) |
341 | | - |
342 | | - assert model.sagemaker_session == session.return_value |
343 | | - |
344 | | - |
345 | | -def test_check_neo_region(sagemaker_session, tmpdir): |
346 | | - sagemaker_session.wait_for_compilation_job = Mock( |
347 | | - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE |
348 | | - ) |
349 | | - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
350 | | - ec2_region_list = [ |
351 | | - "us-east-2", |
352 | | - "us-east-1", |
353 | | - "us-west-1", |
354 | | - "us-west-2", |
355 | | - "ap-east-1", |
356 | | - "ap-south-1", |
357 | | - "ap-northeast-3", |
358 | | - "ap-northeast-2", |
359 | | - "ap-southeast-1", |
360 | | - "ap-southeast-2", |
361 | | - "ap-northeast-1", |
362 | | - "ca-central-1", |
363 | | - "cn-north-1", |
364 | | - "cn-northwest-1", |
365 | | - "eu-central-1", |
366 | | - "eu-west-1", |
367 | | - "eu-west-2", |
368 | | - "eu-west-3", |
369 | | - "eu-north-1", |
370 | | - "sa-east-1", |
371 | | - "us-gov-east-1", |
372 | | - "us-gov-west-1", |
373 | | - ] |
374 | | - neo_support_region = [ |
375 | | - "us-west-1", |
376 | | - "us-west-2", |
377 | | - "us-east-1", |
378 | | - "us-east-2", |
379 | | - "eu-west-1", |
380 | | - "eu-west-2", |
381 | | - "eu-west-3", |
382 | | - "eu-central-1", |
383 | | - "eu-north-1", |
384 | | - "ap-northeast-1", |
385 | | - "ap-northeast-2", |
386 | | - "ap-east-1", |
387 | | - "ap-south-1", |
388 | | - "ap-southeast-1", |
389 | | - "ap-southeast-2", |
390 | | - "sa-east-1", |
391 | | - "ca-central-1", |
392 | | - "me-south-1", |
393 | | - "cn-north-1", |
394 | | - "cn-northwest-1", |
395 | | - "us-gov-west-1", |
396 | | - ] |
397 | | - for region_name in ec2_region_list: |
398 | | - if region_name in neo_support_region: |
399 | | - assert model.check_neo_region(region_name) is True |
400 | | - else: |
401 | | - assert model.check_neo_region(region_name) is False |
402 | | - |
403 | | - |
404 | 234 | @patch("sagemaker.git_utils.git_clone_repo") |
405 | 235 | @patch("sagemaker.model.fw_utils.tar_and_upload_dir") |
406 | 236 | def test_git_support_succeed(tar_and_upload_dir, git_clone_repo, sagemaker_session): |
|
0 commit comments