|
14 | 14 |
|
15 | 15 | import os |
16 | 16 | import re |
| 17 | +import time |
17 | 18 | import uuid |
18 | 19 |
|
19 | 20 | import pytest |
|
34 | 35 | from tests.integ.timeout import timeout |
35 | 36 |
|
36 | 37 |
|
| 38 | +TRAINING_STATUS = "Training" |
| 39 | +ALGO_PULL_FINISHED_MESSAGE = "Training image download completed. Training in progress." |
| 40 | + |
| 41 | + |
| 42 | +def _wait_until_training_can_be_updated(sagemaker_client, job_name, poll=5): |
| 43 | + ready_for_updating = _check_secondary_status(sagemaker_client, job_name) |
| 44 | + while not ready_for_updating: |
| 45 | + time.sleep(poll) |
| 46 | + ready_for_updating = _check_secondary_status(sagemaker_client, job_name) |
| 47 | + |
| 48 | + |
| 49 | +def _check_secondary_status(sagemaker_client, job_name): |
| 50 | + desc = sagemaker_client.describe_training_job(TrainingJobName=job_name) |
| 51 | + secondary_status_transitions = desc.get("SecondaryStatusTransitions") |
| 52 | + if not secondary_status_transitions: |
| 53 | + return False |
| 54 | + |
| 55 | + latest_secondary_status_transition = secondary_status_transitions[-1] |
| 56 | + secondary_status = latest_secondary_status_transition.get("Status") |
| 57 | + status_message = latest_secondary_status_transition.get("StatusMessage") |
| 58 | + return TRAINING_STATUS == secondary_status and ALGO_PULL_FINISHED_MESSAGE == status_message |
| 59 | + |
| 60 | + |
37 | 61 | def test_mxnet_with_default_profiler_config_and_profiler_rule( |
38 | 62 | sagemaker_session, |
39 | 63 | mxnet_training_latest_version, |
@@ -139,6 +163,8 @@ def test_mxnet_with_custom_profiler_config_then_update_rule_and_config( |
139 | 163 | ) |
140 | 164 | assert profiler_rule_configuration["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} |
141 | 165 |
|
| 166 | + _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) |
| 167 | + |
142 | 168 | mx.update_profiler( |
143 | 169 | rules=[ProfilerRule.sagemaker(rule_configs.CPUBottleneck())], |
144 | 170 | system_monitor_interval_millis=500, |
@@ -287,6 +313,8 @@ def test_mxnet_with_profiler_and_debugger_then_disable_framework_metrics( |
287 | 313 | == rule.image_uri |
288 | 314 | ) |
289 | 315 |
|
| 316 | + _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) |
| 317 | + |
290 | 318 | mx.update_profiler(disable_framework_metrics=True) |
291 | 319 | job_description = mx.latest_training_job.describe() |
292 | 320 | assert job_description["ProfilerConfig"]["ProfilingParameters"] == {} |
@@ -338,6 +366,8 @@ def test_mxnet_with_enable_framework_metrics_then_update_framework_metrics( |
338 | 366 | ) |
339 | 367 | assert job_description.get("ProfilingStatus") == "Enabled" |
340 | 368 |
|
| 369 | + _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) |
| 370 | + |
341 | 371 | updated_framework_profile = FrameworkProfile( |
342 | 372 | detailed_profiling_config=DetailedProfilingConfig(profile_default_steps=True) |
343 | 373 | ) |
@@ -397,6 +427,8 @@ def test_mxnet_with_disable_profiler_then_enable_default_profiling( |
397 | 427 | assert job_description.get("ProfilerRuleConfigurations") is None |
398 | 428 | assert job_description.get("ProfilingStatus") == "Disabled" |
399 | 429 |
|
| 430 | + _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) |
| 431 | + |
400 | 432 | mx.enable_default_profiling() |
401 | 433 |
|
402 | 434 | job_description = mx.latest_training_job.describe() |
|
0 commit comments