1616import pytest
1717import sagemaker
1818import os
19+ import warnings
1920
2021from mock import (
2122 Mock ,
6364)
6465from tests .unit import DATA_DIR
6566
66- SCRIPT_FILE = "dummy_script.py"
67- SCRIPT_PATH = os .path .join (DATA_DIR , SCRIPT_FILE )
67+ DUMMY_SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
6868
6969REGION = "us-west-2"
7070BUCKET = "my-bucket"
@@ -129,6 +129,31 @@ def sagemaker_session(boto_session, client):
129129 )
130130
131131
@pytest.fixture
def script_processor(sagemaker_session):
    """Provide a fully configured ``ScriptProcessor`` for processing-step tests.

    Args:
        sagemaker_session: The session fixture the processor is bound to.

    Returns:
        ScriptProcessor: A processor built from a custom image URI with a
        ``python3`` command, KMS keys, environment variables, tags and an
        isolated, traffic-encrypted network configuration.
    """
    # Build the nested network configuration first so the processor
    # arguments below stay flat and easy to scan.
    isolated_network = NetworkConfig(
        subnets=["my_subnet_id"],
        security_group_ids=["my_security_group_id"],
        enable_network_isolation=True,
        encrypt_inter_container_traffic=True,
    )
    processor_kwargs = {
        "role": ROLE,
        "image_uri": "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri",
        "command": ["python3"],
        "instance_type": "ml.m4.xlarge",
        "instance_count": 1,
        "volume_size_in_gb": 100,
        "volume_kms_key": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
        "output_kms_key": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
        "max_runtime_in_seconds": 3600,
        "base_job_name": "my_sklearn_processor",
        "env": {"my_env_variable": "my_env_variable_value"},
        "tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
        "network_config": isolated_network,
        "sagemaker_session": sagemaker_session,
    }
    return ScriptProcessor(**processor_kwargs)
155+
156+
132157def test_custom_step ():
133158 step = CustomStep (
134159 name = "MyStep" , display_name = "CustomStepDisplayName" , description = "CustomStepDescription"
@@ -326,7 +351,7 @@ def test_training_step_tensorflow(sagemaker_session):
326351 training_epochs_parameter = ParameterInteger (name = "TrainingEpochs" , default_value = 5 )
327352 training_batch_size_parameter = ParameterInteger (name = "TrainingBatchSize" , default_value = 500 )
328353 estimator = TensorFlow (
329- entry_point = os . path . join ( DATA_DIR , SCRIPT_FILE ) ,
354+ entry_point = DUMMY_SCRIPT_PATH ,
330355 role = ROLE ,
331356 model_dir = False ,
332357 image_uri = IMAGE_URI ,
@@ -403,6 +428,75 @@ def test_training_step_tensorflow(sagemaker_session):
403428 assert step .properties .TrainingJobName .expr == {"Get" : "Steps.MyTrainingStep.TrainingJobName" }
404429
405430
def test_training_step_profiler_warning(sagemaker_session):
    """Check that a ``TrainingStep`` warns when profiling and caching are both on.

    The estimator is created with ``disable_profiler=False`` and the step is
    given a ``CacheConfig`` with caching enabled; exactly one ``UserWarning``
    mentioning that profiling is enabled is expected.

    Args:
        sagemaker_session: The session fixture passed to the estimator.
    """
    estimator = TensorFlow(
        entry_point=DUMMY_SCRIPT_PATH,
        role=ROLE,
        model_dir=False,
        image_uri=IMAGE_URI,
        source_dir="s3://mybucket/source",
        framework_version="2.4.1",
        py_version="py37",
        disable_profiler=False,
        instance_count=1,
        instance_type="ml.p3.16xlarge",
        sagemaker_session=sagemaker_session,
        hyperparameters={
            "batch-size": 500,
            "epochs": 5,
        },
        debugger_hook_config=False,
        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    )

    inputs = TrainingInput(s3_data=f"s3://{ BUCKET } /train_manifest")
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    with warnings.catch_warnings(record=True) as w:
        # Without an explicit filter, recording depends on the ambient warning
        # filters; force UserWarning to always be surfaced so the assertions
        # below are deterministic.
        warnings.simplefilter("always", UserWarning)
        TrainingStep(
            name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
        )
        assert len(w) == 1
        assert issubclass(w[-1].category, UserWarning)
        assert "Profiling is enabled on the provided estimator" in str(w[-1].message)
461+
462+
def test_training_step_no_profiler_warning(sagemaker_session):
    """Check that a ``TrainingStep`` stays silent unless profiling AND caching are on.

    Two combinations are exercised, and neither may record a warning:
    profiler disabled with a cache config, and profiler enabled with
    ``cache_config=None``.

    Args:
        sagemaker_session: The session fixture passed to the estimator.
    """
    estimator = TensorFlow(
        entry_point=DUMMY_SCRIPT_PATH,
        role=ROLE,
        model_dir=False,
        image_uri=IMAGE_URI,
        source_dir="s3://mybucket/source",
        framework_version="2.4.1",
        py_version="py37",
        disable_profiler=True,
        instance_count=1,
        instance_type="ml.p3.16xlarge",
        sagemaker_session=sagemaker_session,
        hyperparameters={
            "batch-size": 500,
            "epochs": 5,
        },
        debugger_hook_config=False,
        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    )

    inputs = TrainingInput(s3_data=f"s3://{ BUCKET } /train_manifest")
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    with warnings.catch_warnings(record=True) as w:
        # Force UserWarning to be recorded if raised; otherwise an ambient
        # filter could hide it and this negative assertion would pass vacuously.
        warnings.simplefilter("always", UserWarning)
        # profiler disabled, cache config not None
        TrainingStep(
            name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
        )
        assert len(w) == 0

    with warnings.catch_warnings(record=True) as w:
        # Filters are restored when the previous context exits, so the
        # "always" filter has to be installed again here.
        warnings.simplefilter("always", UserWarning)
        # profiler enabled, cache config is None
        estimator.disable_profiler = False
        TrainingStep(name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=None)
        assert len(w) == 0
498+
499+
406500def test_processing_step (sagemaker_session ):
407501 processing_input_data_uri_parameter = ParameterString (
408502 name = "ProcessingInputDataUri" , default_value = f"s3://{ BUCKET } /processing_manifest"
@@ -473,28 +567,42 @@ def test_processing_step(sagemaker_session):
473567
474568
475569@patch ("sagemaker.processing.ScriptProcessor._normalize_args" )
476- def test_processing_step_normalizes_args (mock_normalize_args , sagemaker_session ):
477- processor = ScriptProcessor (
478- role = ROLE ,
479- image_uri = "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri" ,
480- command = ["python3" ],
481- instance_type = "ml.m4.xlarge" ,
482- instance_count = 1 ,
483- volume_size_in_gb = 100 ,
484- volume_kms_key = "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key" ,
485- output_kms_key = "arn:aws:kms:us-west-2:012345678901:key/output-kms-key" ,
486- max_runtime_in_seconds = 3600 ,
487- base_job_name = "my_sklearn_processor" ,
488- env = {"my_env_variable" : "my_env_variable_value" },
489- tags = [{"Key" : "my-tag" , "Value" : "my-tag-value" }],
490- network_config = NetworkConfig (
491- subnets = ["my_subnet_id" ],
492- security_group_ids = ["my_security_group_id" ],
493- enable_network_isolation = True ,
494- encrypt_inter_container_traffic = True ,
495- ),
496- sagemaker_session = sagemaker_session ,
def test_processing_step_normalizes_args_with_local_code(mock_normalize_args, script_processor):
    """Check the arguments forwarded to ``_normalize_args`` for a local script.

    A ``ProcessingStep`` is built with ``code`` pointing at the dummy script
    path; after ``to_request()`` the patched ``_normalize_args`` must have been
    called with the step's arguments, inputs, outputs and code, plus a job
    name made of the step name and a hash suffix.

    Args:
        mock_normalize_args: Mock replacing ``ScriptProcessor._normalize_args``.
        script_processor: The shared ``ScriptProcessor`` fixture.
    """
    manifest_source = f"s3://{ BUCKET } /processing_manifest"
    manifest_destination = "processing_manifest"

    step = ProcessingStep(
        name="MyProcessingStep",
        processor=script_processor,
        code=DUMMY_SCRIPT_PATH,
        inputs=[ProcessingInput(source=manifest_source, destination=manifest_destination)],
        outputs=[ProcessingOutput(source=manifest_source, destination=manifest_destination)],
        job_arguments=["arg1", "arg2"],
        cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
    )

    # The mock's return value must be set before the request is built.
    mock_normalize_args.return_value = [step.inputs, step.outputs]
    step.to_request()

    expected_kwargs = {
        # NOTE(review): the suffix is presumably a hash tied to the local
        # code file - confirm against ProcessingStep's implementation.
        "job_name": "MyProcessingStep-3e89f0c7e101c356cbedf27d9d27e9db",
        "arguments": step.job_arguments,
        "inputs": step.inputs,
        "outputs": step.outputs,
        "code": step.code,
    }
    mock_normalize_args.assert_called_with(**expected_kwargs)
602+
603+
604+ @patch ("sagemaker.processing.ScriptProcessor._normalize_args" )
605+ def test_processing_step_normalizes_args_with_s3_code (mock_normalize_args , script_processor ):
498606 cache_config = CacheConfig (enable_caching = True , expire_after = "PT1H" )
499607 inputs = [
500608 ProcessingInput (
@@ -510,8 +618,8 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session)
510618 ]
511619 step = ProcessingStep (
512620 name = "MyProcessingStep" ,
513- processor = processor ,
514- code = "foo.py " ,
621+ processor = script_processor ,
622+ code = "s3:// foo" ,
515623 inputs = inputs ,
516624 outputs = outputs ,
517625 job_arguments = ["arg1" , "arg2" ],
@@ -520,13 +628,48 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session)
520628 mock_normalize_args .return_value = [step .inputs , step .outputs ]
521629 step .to_request ()
522630 mock_normalize_args .assert_called_with (
631+ job_name = None ,
523632 arguments = step .job_arguments ,
524633 inputs = step .inputs ,
525634 outputs = step .outputs ,
526635 code = step .code ,
527636 )
528637
529638
@patch("sagemaker.processing.ScriptProcessor._normalize_args")
def test_processing_step_normalizes_args_with_no_code(mock_normalize_args, script_processor):
    """Check the arguments forwarded to ``_normalize_args`` when no code is given.

    The ``ProcessingStep`` is built without a ``code`` argument; after
    ``to_request()`` the patched ``_normalize_args`` must have been called
    with ``job_name=None`` and ``code=None`` alongside the step's arguments,
    inputs and outputs.

    Args:
        mock_normalize_args: Mock replacing ``ScriptProcessor._normalize_args``.
        script_processor: The shared ``ScriptProcessor`` fixture.
    """
    manifest_source = f"s3://{ BUCKET } /processing_manifest"
    manifest_destination = "processing_manifest"

    step = ProcessingStep(
        name="MyProcessingStep",
        processor=script_processor,
        inputs=[ProcessingInput(source=manifest_source, destination=manifest_destination)],
        outputs=[ProcessingOutput(source=manifest_source, destination=manifest_destination)],
        job_arguments=["arg1", "arg2"],
        cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
    )

    # The mock's return value must be set before the request is built.
    mock_normalize_args.return_value = [step.inputs, step.outputs]
    step.to_request()

    expected_kwargs = {
        "job_name": None,
        "arguments": step.job_arguments,
        "inputs": step.inputs,
        "outputs": step.outputs,
        "code": None,
    }
    mock_normalize_args.assert_called_with(**expected_kwargs)
671+
672+
530673def test_create_model_step (sagemaker_session ):
531674 model = Model (
532675 image_uri = IMAGE_URI ,
0 commit comments