6060# TODO-reinvent-2019: test get_debugger_artifacts_path and get_tensorboard_artifacts_path
6161
6262
@pytest.fixture
def actions():
    """Provide the rule actions shared by the ``*_and_actions`` debugger tests.

    Returns:
        The ``rule_configs.ActionList`` built from a stop-training action,
        an e-mail action and an SMS action, in that order.
    """
    return rule_configs.ActionList(
        rule_configs.StopTraining(),
        # Placeholder recipient: example.com is reserved for testing and
        # documentation (RFC 2606), so no real mailbox is targeted.
        # NOTE(review): confirm this matches the address the suite expects.
        rule_configs.Email("abc@example.com"),
        rule_configs.SMS("+01234567890"),
    )
70+
71+
6372def test_mxnet_with_rules (
6473 sagemaker_session ,
6574 mxnet_training_latest_version ,
@@ -125,6 +134,74 @@ def test_mxnet_with_rules(
125134 _wait_and_assert_that_no_rule_jobs_errored (training_job = mx .latest_training_job )
126135
127136
def test_mxnet_with_rules_and_actions(
    sagemaker_session,
    mxnet_training_latest_version,
    mxnet_training_latest_py_version,
    cpu_instance_type,
    actions,
):
    """Train the MXNet MNIST script with built-in rules that carry actions.

    Three built-in rules (vanishing gradient, all-zero, loss-not-decreasing)
    are attached to the estimator, each with the ``actions`` fixture. After
    fitting, the described training job is checked rule by rule against the
    local ``Rule`` objects, the rule evaluation statuses are compared with
    the job's rule summary, and the rule jobs are checked for errors.
    """
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        vanishing_gradient_rule = Rule.sagemaker(
            rule_configs.vanishing_gradient(), actions=actions
        )
        all_zero_rule = Rule.sagemaker(
            base_config=rule_configs.all_zero(),
            rule_parameters={"tensor_regex": ".*"},
            actions=actions,
        )
        loss_not_decreasing_rule = Rule.sagemaker(
            rule_configs.loss_not_decreasing(), actions=actions
        )
        rules = [vanishing_gradient_rule, all_zero_rule, loss_not_decreasing_rule]

        mnist_dir = os.path.join(DATA_DIR, "mxnet_mnist")
        entry_script = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")

        estimator = MXNet(
            entry_point=entry_script,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=1,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            rules=rules,
        )

        # Upload order (train, then test) is kept as in the original test.
        train_s3_uri = estimator.sagemaker_session.upload_data(
            path=os.path.join(mnist_dir, "train"),
            key_prefix="integ-test-data/mxnet_mnist/train",
        )
        test_s3_uri = estimator.sagemaker_session.upload_data(
            path=os.path.join(mnist_dir, "test"),
            key_prefix="integ-test-data/mxnet_mnist/test",
        )

        estimator.fit({"train": train_s3_uri, "test": test_s3_uri})

        training_job = estimator.latest_training_job
        job_description = training_job.describe()

        # The described configurations are matched to the local rules by position.
        for position, expected_rule in enumerate(rules):
            described = job_description["DebugRuleConfigurations"][position]
            assert described["RuleConfigurationName"] == expected_rule.name
            assert described["RuleEvaluatorImage"] == expected_rule.image_uri
            assert described["VolumeSizeInGB"] == 0
            assert (
                described["RuleParameters"]["rule_to_invoke"]
                == expected_rule.rule_parameters["rule_to_invoke"]
            )

        described_statuses = _get_rule_evaluation_statuses(job_description)
        summarized_statuses = estimator.latest_training_job.rule_job_summary()
        assert described_statuses == summarized_statuses

        _wait_and_assert_that_no_rule_jobs_errored(
            training_job=estimator.latest_training_job
        )
203+
204+
128205def test_mxnet_with_custom_rule (
129206 sagemaker_session ,
130207 mxnet_training_latest_version ,
@@ -178,6 +255,60 @@ def test_mxnet_with_custom_rule(
178255 _wait_and_assert_that_no_rule_jobs_errored (training_job = mx .latest_training_job )
179256
180257
def test_mxnet_with_custom_rule_and_actions(
    sagemaker_session,
    mxnet_training_latest_version,
    mxnet_training_latest_py_version,
    cpu_instance_type,
    actions,
):
    """Train the MXNet MNIST script with a custom rule that carries actions.

    The single rule comes from ``_get_custom_rule`` and is given the
    ``actions`` fixture. After fitting, the described training job's rule
    configuration is checked against the local rule (name, evaluator image
    and a 30 GB volume size), the rule evaluation statuses are compared with
    the job's rule summary, and the rule jobs are checked for errors.
    """
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        custom_rule = _get_custom_rule(sagemaker_session, actions)
        rules = [custom_rule]

        mnist_dir = os.path.join(DATA_DIR, "mxnet_mnist")
        entry_script = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")

        estimator = MXNet(
            entry_point=entry_script,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=1,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            rules=rules,
        )

        # Upload order (train, then test) is kept as in the original test.
        train_s3_uri = estimator.sagemaker_session.upload_data(
            path=os.path.join(mnist_dir, "train"),
            key_prefix="integ-test-data/mxnet_mnist/train",
        )
        test_s3_uri = estimator.sagemaker_session.upload_data(
            path=os.path.join(mnist_dir, "test"),
            key_prefix="integ-test-data/mxnet_mnist/test",
        )

        estimator.fit({"train": train_s3_uri, "test": test_s3_uri})

        training_job = estimator.latest_training_job
        job_description = training_job.describe()

        # The described configurations are matched to the local rules by position.
        for position, expected_rule in enumerate(rules):
            described = job_description["DebugRuleConfigurations"][position]
            assert described["RuleConfigurationName"] == expected_rule.name
            assert described["RuleEvaluatorImage"] == expected_rule.image_uri
            assert described["VolumeSizeInGB"] == 30

        described_statuses = _get_rule_evaluation_statuses(job_description)
        summarized_statuses = estimator.latest_training_job.rule_job_summary()
        assert described_statuses == summarized_statuses

        _wait_and_assert_that_no_rule_jobs_errored(
            training_job=estimator.latest_training_job
        )
310+
311+
181312def test_mxnet_with_debugger_hook_config (
182313 sagemaker_session ,
183314 mxnet_training_latest_version ,
@@ -514,7 +645,7 @@ def _get_rule_evaluation_statuses(job_description):
514645 return debug_rule_eval_statuses + profiler_rule_eval_statuses
515646
516647
517- def _get_custom_rule (session ):
648+ def _get_custom_rule (session , actions = None ):
518649 script_path = os .path .join (DATA_DIR , "mxnet_mnist" , "my_custom_rule.py" )
519650
520651 return Rule .custom (
@@ -526,6 +657,7 @@ def _get_custom_rule(session):
526657 image_uri = CUSTOM_RULE_REPO_WITH_PLACEHOLDERS .format (
527658 CUSTOM_RULE_CONTAINERS_ACCOUNTS_MAP [session .boto_region_name ], session .boto_region_name
528659 ),
660+ actions = actions ,
529661 )
530662
531663
0 commit comments