Skip to content

Commit 8121734

Browse files
committed
Adding helpers for logging
1 parent 8194f1b commit 8121734

File tree

1 file changed

+375
-0
lines changed

1 file changed

+375
-0
lines changed
Lines changed: 375 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
import logging
import sys
import time

import botocore
import botocore.config
import botocore.exceptions

from sagemaker.core.common_utils import (
    secondary_training_status_changed,
    secondary_training_status_message,
)
# TODO(review): `exceptions` was referenced below without an import; module
# path inferred from the sibling sagemaker.core imports — confirm it exposes
# CapacityError and UnexpectedStatusException.
from sagemaker.core import exceptions
from sagemaker.core.logs import ColorWrap, Position, multi_stream_iter
9+
10+
class LogState(object):
    """States for the log-tailing state machine driven by ``logs_for_job``.

    ``logs_for_job`` starts in TAILING (or COMPLETE when not waiting),
    moves to JOB_COMPLETE when the job reaches a terminal status, then to
    COMPLETE after one final drain of the log streams. STARTING and
    WAIT_IN_PROGRESS are not referenced elsewhere in this module.
    """

    STARTING = 1
    WAIT_IN_PROGRESS = 2
    TAILING = 3
    JOB_COMPLETE = 4
    COMPLETE = 5
17+
18+
# Maps upper-case job status codes (as returned by some SageMaker APIs) to
# their CamelCase equivalents used throughout this module.
STATUS_CODE_TABLE = {
    "COMPLETED": "Completed",
    "INPROGRESS": "InProgress",
    "IN_PROGRESS": "InProgress",
    "FAILED": "Failed",
    "STOPPED": "Stopped",
    "STOPPING": "Stopping",
    "STARTING": "Starting",
    "PENDING": "Pending",
}
28+
29+
30+
def wait_until(callable_fn, poll=5):
    """Poll ``callable_fn`` until it returns a non-``None`` result.

    Args:
        callable_fn (Callable): Zero-argument callable; polling stops when it
            returns anything other than ``None``.
        poll (int): Seconds to sleep before each attempt (default: 5).

    Returns:
        The first non-``None`` value returned by ``callable_fn``.

    Raises:
        botocore.exceptions.ClientError: Re-raised for any client error other
            than ``AccessDeniedException`` within the first 5 minutes.
    """
    elapsed_time = 0
    result = None
    while result is None:
        try:
            elapsed_time += poll
            time.sleep(poll)
            result = callable_fn()
        except botocore.exceptions.ClientError as err:
            # For the initial 5 mins we accept/pass AccessDeniedException.
            # The reason is to await tag propagation to avoid false AccessDenied claims for an
            # access policy based on resource tags. The caveat here is for true AccessDenied
            # cases the routine will fail after 5 mins.
            if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
                # BUG FIX: `logger` was used without ever being defined, which
                # raised a NameError here; use a stdlib module logger instead.
                logging.getLogger(__name__).warning(
                    "Received AccessDeniedException. This could mean the IAM role does not "
                    "have the resource permissions, in which case please add resource access "
                    "and retry. For cases where the role has tag based resource policy, "
                    "continuing to wait for tag propagation.."
                )
                continue
            raise err
    return result
54+
55+
56+
def get_initial_job_state(description, status_key, wait):
    """Pick the starting state for the log-tailing state machine.

    Tailing is only worthwhile when the caller asked to wait and the job has
    not already reached a terminal status; otherwise we read once and exit.
    """
    terminal_statuses = ("Completed", "Failed", "Stopped")
    if wait and description[status_key] not in terminal_statuses:
        return LogState.TAILING
    return LogState.COMPLETE
61+
62+
63+
def logs_init(boto_session, description, job):
    """Initialize the state needed to tail CloudWatch logs for a SageMaker job.

    Args:
        boto_session (boto3.Session): Session used to create the CloudWatch
            Logs client.
        description (dict): Result of the job's ``Describe*`` API call.
        job (str): Job type; one of "Training", "Transform", "Processing",
            or "AutoML".

    Returns:
        tuple: ``(instance_count, stream_names, positions, client, log_group,
        dot, color_wrap)`` — the mutable tailing state consumed by
        ``flush_log_streams``.

    Raises:
        ValueError: If ``job`` is not a recognized job type.
    """
    if job == "Training":
        if "InstanceGroups" in description["ResourceConfig"]:
            # Heterogeneous cluster: total the instances across all groups.
            instance_count = sum(
                group["InstanceCount"]
                for group in description["ResourceConfig"]["InstanceGroups"]
            )
        else:
            instance_count = description["ResourceConfig"]["InstanceCount"]
    elif job == "Transform":
        instance_count = description["TransformResources"]["InstanceCount"]
    elif job == "Processing":
        instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
    elif job == "AutoML":
        instance_count = 0
    else:
        # BUG FIX: an unrecognized job type previously fell through and
        # produced a confusing NameError on `instance_count` below.
        raise ValueError("Unknown job type: {}".format(job))

    stream_names = []  # The list of log streams
    positions = {}  # The current position in each stream, map of stream name -> position

    # Increase retries allowed (from default of 4), as we don't want waiting for a training job
    # to be interrupted by a transient exception.
    config = botocore.config.Config(retries={"max_attempts": 15})
    client = boto_session.client("logs", config=config)
    log_group = "/aws/sagemaker/" + job + "Jobs"

    dot = False

    color_wrap = ColorWrap()

    return instance_count, stream_names, positions, client, log_group, dot, color_wrap
93+
94+
95+
def flush_log_streams(
    stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
):
    """Discover log streams for a job and print any new events from them.

    Args:
        stream_names (list[str]): Known log stream names so far.
        instance_count (int): Number of instances expected to emit streams.
        client: CloudWatch Logs client.
        log_group (str): Log group to read from.
        job_name (str): Job name; streams are prefixed with ``job_name + "/"``.
        positions (dict): Map of stream name -> Position; mutated in place.
        dot (bool): Whether a progress dot was printed on the current line.
        color_wrap (ColorWrap): Callable used to print a stream's events.

    NOTE(review): `stream_names` and `dot` are rebound locally, so those
    updates do not propagate back to the caller — only `positions` is
    mutated in place. Verify the caller re-discovers streams as intended.
    """
    if len(stream_names) < instance_count:
        # Log streams are created whenever a container starts writing to stdout/err, so this list
        # may be dynamic until we have a stream for every instance.
        try:
            streams = client.describe_log_streams(
                logGroupName=log_group,
                logStreamNamePrefix=job_name + "/",
                orderBy="LogStreamName",
                limit=min(instance_count, 50),
            )
            stream_names = [s["logStreamName"] for s in streams["logStreams"]]

            while "nextToken" in streams:
                # BUG FIX: pass the pagination token; without it this loop
                # re-requested the first page forever (the response always
                # contained "nextToken", so the loop never terminated).
                streams = client.describe_log_streams(
                    logGroupName=log_group,
                    logStreamNamePrefix=job_name + "/",
                    orderBy="LogStreamName",
                    limit=50,
                    nextToken=streams["nextToken"],
                )

                stream_names.extend([s["logStreamName"] for s in streams["logStreams"]])

            positions.update(
                [
                    (s, Position(timestamp=0, skip=0))
                    for s in stream_names
                    if s not in positions
                ]
            )
        # BUG FIX: `ClientError` was referenced unqualified but never
        # imported; qualify it through the imported botocore package.
        except botocore.exceptions.ClientError as e:
            # On the very first training job run on an account, there's no log group until
            # the container starts logging, so ignore any errors thrown about that
            err = e.response.get("Error", {})
            if err.get("Code", None) != "ResourceNotFoundException":
                raise

    if len(stream_names) > 0:
        if dot:
            print("")
            dot = False
        for idx, event in multi_stream_iter(
            client, log_group, stream_names, positions
        ):
            color_wrap(idx, event["message"])
            ts, count = positions[stream_names[idx]]
            if event["timestamp"] == ts:
                # Same-timestamp event: bump skip so we don't re-print it.
                positions[stream_names[idx]] = Position(
                    timestamp=ts, skip=count + 1
                )
            else:
                positions[stream_names[idx]] = Position(
                    timestamp=event["timestamp"], skip=1
                )
    else:
        # No streams yet: print a progress dot so the user sees activity.
        dot = True
        print(".", end="")
        sys.stdout.flush()
156+
157+
158+
def rule_statuses_changed(current_statuses, last_statuses):
    """Report whether any Debugger/Profiler rule evaluation status changed.

    Returns True when there is no previous snapshot at all, or when any rule
    whose configuration name matches positionally between the two snapshots
    has a different evaluation status.
    """
    if not last_statuses:
        return True

    return any(
        current["RuleConfigurationName"] == last["RuleConfigurationName"]
        and current["RuleEvaluationStatus"] != last["RuleEvaluationStatus"]
        for current, last in zip(current_statuses, last_statuses)
    )
170+
171+
172+
def check_job_status(job, desc, status_key_name):
    """Check to see if the job completed successfully.

    If not, construct and raise a exceptions. (UnexpectedStatusException).

    Args:
        job (str): The name of the job to check.
        desc (dict[str, str]): The result of ``describe_training_job()``.
        status_key_name (str): Status key name to check for.

    Raises:
        exceptions.CapacityError: If the training job fails with CapacityError.
        exceptions.UnexpectedStatusException: If the training job fails.
    """
    status = desc[status_key_name]
    # If the status is capital case, then convert it to Camel case
    status = STATUS_CODE_TABLE.get(status, status)

    if status == "Stopped":
        # BUG FIX: `logger` was used without ever being defined, which raised
        # a NameError on the Stopped path; use a stdlib module logger instead.
        logging.getLogger(__name__).warning(
            "Job ended with status 'Stopped' rather than 'Completed'. "
            "This could mean the job timed out or stopped early for some other reason: "
            "Consider checking whether it completed as you expect."
        )
    elif status != "Completed":
        reason = desc.get("FailureReason", "(No reason provided)")
        # e.g. "TrainingJobStatus" -> "Training job" in the error message.
        job_type = status_key_name.replace("JobStatus", " job")
        troubleshooting = (
            "https://docs.aws.amazon.com/sagemaker/latest/dg/"
            "sagemaker-python-sdk-troubleshooting.html"
        )
        message = (
            "Error for {job_type} {job_name}: {status}. Reason: {reason}. "
            "Check troubleshooting guide for common errors: {troubleshooting}"
        ).format(
            job_type=job_type,
            job_name=job,
            status=status,
            reason=reason,
            troubleshooting=troubleshooting,
        )
        # Capacity problems get their own exception type so callers can retry.
        if "CapacityError" in str(reason):
            raise exceptions.CapacityError(
                message=message,
                allowed_statuses=["Completed", "Stopped"],
                actual_status=status,
            )
        raise exceptions.UnexpectedStatusException(
            message=message,
            allowed_statuses=["Completed", "Stopped"],
            actual_status=status,
        )
224+
225+
226+
def logs_for_job(
    model_trainer, wait=False, poll=10, log_type="All", timeout=None
):
    """Display logs for a given training job, optionally tailing them until job is complete.

    If the output is a tty or a Jupyter cell, it will be color-coded
    based on which instance the log entry is from.

    Args:
        model_trainer (sagemaker.train.ModelTrainer): The ModelTrainer used for the
            training job
        wait (bool): Whether to keep looking for new log entries until the job completes
            (default: False).
        poll (int): The interval in seconds between polling for new log entries and job
            completion (default: 10).
        log_type ([str]): A list of strings specifying which logs to print. Acceptable
            strings are "All", "None", "Training", or "Rules". To maintain backwards
            compatibility, boolean values are also accepted and converted to strings.
            NOTE(review): the code below compares ``log_type`` against the set
            {"All", "Rules"} directly, which only matches when a single string is
            passed — confirm list inputs are handled upstream.
        timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by
            default.
    Returns:
        Last call to sagemaker DescribeTrainingJob
    Raises:
        exceptions.CapacityError: If the training job fails with CapacityError.
        exceptions.UnexpectedStatusException: If waiting and the training job fails.
    """
    sagemaker_session = model_trainer.sagemaker_session
    job_name = model_trainer._latest_training_job.training_job_name

    sagemaker_client = sagemaker_session.sagemaker_client
    request_end_time = time.time() + timeout if timeout else None
    # wait_until retries AccessDenied for up to 5 minutes (tag propagation).
    description = wait_until(
        lambda: sagemaker_client.describe_training_job(TrainingJobName=job_name)
    )
    print(secondary_training_status_message(description, None), end="")

    instance_count, stream_names, positions, client, log_group, dot, color_wrap = logs_init(
        sagemaker_session.boto_session, description, job="Training"
    )

    state = get_initial_job_state(description, "TrainingJobStatus", wait)

    # The loop below implements a state machine that alternates between checking the job status
    # and reading whatever is available in the logs at this point. Note, that if we were
    # called with wait == False, we never check the job status.
    #
    # If wait == TRUE and job is not completed, the initial state is TAILING
    # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
    # complete).
    #
    # The state table:
    #
    # STATE               ACTIONS                        CONDITION           NEW STATE
    # ----------------    ----------------               -----------------   ----------------
    # TAILING             Read logs, Pause, Get status   Job complete        JOB_COMPLETE
    #                                                    Else                TAILING
    # JOB_COMPLETE        Read logs, Pause               Any                 COMPLETE
    # COMPLETE            Read logs, Exit                N/A
    #
    # Notes:
    # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
    #   Cloudwatch after the job was marked complete.
    last_describe_job_call = time.time()
    last_description = description
    last_debug_rule_statuses = None
    last_profiler_rule_statuses = None

    while True:
        # NOTE(review): flush_log_streams rebinds `stream_names` and `dot`
        # locally, so only `positions` is updated across iterations here —
        # stream discovery repeats every loop; confirm this is intended.
        flush_log_streams(
            stream_names,
            instance_count,
            client,
            log_group,
            job_name,
            positions,
            dot,
            color_wrap,
        )
        if timeout and time.time() > request_end_time:
            print("Timeout Exceeded. {} seconds elapsed.".format(timeout))
            break

        if state == LogState.COMPLETE:
            break

        time.sleep(poll)

        if state == LogState.JOB_COMPLETE:
            state = LogState.COMPLETE
        elif time.time() - last_describe_job_call >= 30:
            # Refresh the job description at most every 30 seconds to limit
            # DescribeTrainingJob API traffic.
            description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
            last_describe_job_call = time.time()

            if secondary_training_status_changed(description, last_description):
                print()
                print(secondary_training_status_message(description, last_description), end="")
                last_description = description

            status = description["TrainingJobStatus"]

            if status in ("Completed", "Failed", "Stopped"):
                print()
                state = LogState.JOB_COMPLETE

            # Print prettified logs related to the status of SageMaker Debugger rules.
            debug_rule_statuses = description.get("DebugRuleEvaluationStatuses", {})
            if (
                debug_rule_statuses
                and rule_statuses_changed(debug_rule_statuses, last_debug_rule_statuses)
                and (log_type in {"All", "Rules"})
            ):
                for status in debug_rule_statuses:
                    rule_log = (
                        f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
                    )
                    print(rule_log)

                last_debug_rule_statuses = debug_rule_statuses

            # Print prettified logs related to the status of SageMaker Profiler rules.
            profiler_rule_statuses = description.get("ProfilerRuleEvaluationStatuses", {})
            if (
                profiler_rule_statuses
                and rule_statuses_changed(profiler_rule_statuses, last_profiler_rule_statuses)
                and (log_type in {"All", "Rules"})
            ):
                for status in profiler_rule_statuses:
                    rule_log = (
                        f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
                    )
                    print(rule_log)

                last_profiler_rule_statuses = profiler_rule_statuses

    if wait:
        # Raises for Failed / unexpected terminal statuses.
        check_job_status(job_name, description, "TrainingJobStatus")
        if dot:
            print()
        # Customers are not billed for hardware provisioning, so billable time is less than
        # total time
        training_time = description.get("TrainingTimeInSeconds")
        billable_time = description.get("BillableTimeInSeconds")
        if training_time is not None:
            print("Training seconds:", training_time * instance_count)
        if billable_time is not None:
            print("Billable seconds:", billable_time * instance_count)
            # NOTE(review): the savings math assumes training_time is present
            # and non-zero whenever billable_time is — confirm, else this can
            # raise TypeError/ZeroDivisionError for spot jobs.
            if description.get("EnableManagedSpotTraining"):
                saving = (1 - float(billable_time) / training_time) * 100
                print("Managed Spot Training savings: {:.1f}%".format(saving))
    return last_description

0 commit comments

Comments
 (0)