@@ -726,6 +726,110 @@ def test_framework_with_no_default_profiler_in_unsupported_region(region):
726726 assert args .get ("profiler_rule_configs" ) is None
727727
728728
729+ @pytest .mark .parametrize ("region" , PROFILER_UNSUPPORTED_REGIONS )
730+ def test_framework_with_debugger_config_set_up_in_unsupported_region (region ):
731+ with pytest .raises (ValueError ) as error :
732+ boto_mock = Mock (name = "boto_session" , region_name = region )
733+ sms = MagicMock (
734+ name = "sagemaker_session" ,
735+ boto_session = boto_mock ,
736+ boto_region_name = region ,
737+ config = None ,
738+ local_mode = False ,
739+ s3_client = None ,
740+ s3_resource = None ,
741+ )
742+ f = DummyFramework (
743+ entry_point = SCRIPT_PATH ,
744+ role = ROLE ,
745+ sagemaker_session = sms ,
746+ instance_count = INSTANCE_COUNT ,
747+ instance_type = INSTANCE_TYPE ,
748+ debugger_hook_config = DebuggerHookConfig (s3_output_path = "s3://output" ),
749+ )
750+ f .fit ("s3://mydata" )
751+
752+ assert "Current region does not support debugger but debugger hook config is set!" in str (error )
753+
754+
755+ @pytest .mark .parametrize ("region" , PROFILER_UNSUPPORTED_REGIONS )
756+ def test_framework_enable_profiling_in_unsupported_region (region ):
757+ with pytest .raises (ValueError ) as error :
758+ boto_mock = Mock (name = "boto_session" , region_name = region )
759+ sms = MagicMock (
760+ name = "sagemaker_session" ,
761+ boto_session = boto_mock ,
762+ boto_region_name = region ,
763+ config = None ,
764+ local_mode = False ,
765+ s3_client = None ,
766+ s3_resource = None ,
767+ )
768+ f = DummyFramework (
769+ entry_point = SCRIPT_PATH ,
770+ role = ROLE ,
771+ sagemaker_session = sms ,
772+ instance_count = INSTANCE_COUNT ,
773+ instance_type = INSTANCE_TYPE ,
774+ )
775+ f .fit ("s3://mydata" )
776+ f .enable_default_profiling ()
777+
778+ assert "Current region does not support profiler / debugger!" in str (error )
779+
780+
781+ @pytest .mark .parametrize ("region" , PROFILER_UNSUPPORTED_REGIONS )
782+ def test_framework_update_profiling_in_unsupported_region (region ):
783+ with pytest .raises (ValueError ) as error :
784+ boto_mock = Mock (name = "boto_session" , region_name = region )
785+ sms = MagicMock (
786+ name = "sagemaker_session" ,
787+ boto_session = boto_mock ,
788+ boto_region_name = region ,
789+ config = None ,
790+ local_mode = False ,
791+ s3_client = None ,
792+ s3_resource = None ,
793+ )
794+ f = DummyFramework (
795+ entry_point = SCRIPT_PATH ,
796+ role = ROLE ,
797+ sagemaker_session = sms ,
798+ instance_count = INSTANCE_COUNT ,
799+ instance_type = INSTANCE_TYPE ,
800+ )
801+ f .fit ("s3://mydata" )
802+ f .update_profiler (system_monitor_interval_millis = 1000 )
803+
804+ assert "Current region does not support profiler / debugger!" in str (error )
805+
806+
807+ @pytest .mark .parametrize ("region" , PROFILER_UNSUPPORTED_REGIONS )
808+ def test_framework_disable_profiling_in_unsupported_region (region ):
809+ with pytest .raises (ValueError ) as error :
810+ boto_mock = Mock (name = "boto_session" , region_name = region )
811+ sms = MagicMock (
812+ name = "sagemaker_session" ,
813+ boto_session = boto_mock ,
814+ boto_region_name = region ,
815+ config = None ,
816+ local_mode = False ,
817+ s3_client = None ,
818+ s3_resource = None ,
819+ )
820+ f = DummyFramework (
821+ entry_point = SCRIPT_PATH ,
822+ role = ROLE ,
823+ sagemaker_session = sms ,
824+ instance_count = INSTANCE_COUNT ,
825+ instance_type = INSTANCE_TYPE ,
826+ )
827+ f .fit ("s3://mydata" )
828+ f .disable_profiling ()
829+
830+ assert "Current region does not support profiler / debugger!" in str (error )
831+
832+
729833def test_framework_with_profiler_config_and_profiler_disabled (sagemaker_session ):
730834 with pytest .raises (RuntimeError ) as error :
731835 f = DummyFramework (
@@ -2683,6 +2787,7 @@ def test_generic_to_fit_no_input(time, sagemaker_session):
26832787
26842788 args .pop ("job_name" )
26852789 args .pop ("role" )
2790+ args .pop ("debugger_hook_config" )
26862791
26872792 assert args == NO_INPUT_TRAIN_CALL
26882793
@@ -2707,6 +2812,7 @@ def test_generic_to_fit_no_hps(time, sagemaker_session):
27072812
27082813 args .pop ("job_name" )
27092814 args .pop ("role" )
2815+ args .pop ("debugger_hook_config" )
27102816
27112817 assert args == BASE_TRAIN_CALL
27122818
@@ -2733,6 +2839,7 @@ def test_generic_to_fit_with_hps(time, sagemaker_session):
27332839
27342840 args .pop ("job_name" )
27352841 args .pop ("role" )
2842+ args .pop ("debugger_hook_config" )
27362843
27372844 assert args == HP_TRAIN_CALL
27382845
@@ -2764,6 +2871,7 @@ def test_generic_to_fit_with_experiment_config(time, sagemaker_session):
27642871
27652872 args .pop ("job_name" )
27662873 args .pop ("role" )
2874+ args .pop ("debugger_hook_config" )
27672875
27682876 assert args == EXP_TRAIN_CALL
27692877
@@ -2917,6 +3025,7 @@ def test_generic_to_deploy(time, sagemaker_session):
29173025
29183026 args .pop ("job_name" )
29193027 args .pop ("role" )
3028+ args .pop ("debugger_hook_config" )
29203029
29213030 assert args == HP_TRAIN_CALL
29223031
@@ -3727,7 +3836,6 @@ def test_script_mode_estimator_same_calls_as_framework(
37273836 source_dir = script_uri ,
37283837 image_uri = IMAGE_URI ,
37293838 model_uri = model_uri ,
3730- environment = {"USE_SMDEBUG" : "0" },
37313839 dependencies = [],
37323840 debugger_hook_config = {},
37333841 )
0 commit comments