import logging
import sys
import time

import botocore
from botocore.exceptions import ClientError

from sagemaker.core.logs import ColorWrap, Position, multi_stream_iter
from sagemaker.core.common_utils import (
    secondary_training_status_changed,
    secondary_training_status_message,
)

# Assumed import path for the SDK exception types (CapacityError, UnexpectedStatusException)
# referenced in check_job_status below.
from sagemaker import exceptions

logger = logging.getLogger(__name__)


class LogState(object):
    """States of the log-tailing state machine used by ``logs_for_job``."""

    STARTING = 1
    WAIT_IN_PROGRESS = 2
    TAILING = 3
    JOB_COMPLETE = 4
    COMPLETE = 5


# Map capital-case job status codes to their Camel-case equivalents.
STATUS_CODE_TABLE = {
    "COMPLETED": "Completed",
    "INPROGRESS": "InProgress",
    "IN_PROGRESS": "InProgress",
    "FAILED": "Failed",
    "STOPPED": "Stopped",
    "STOPPING": "Stopping",
    "STARTING": "Starting",
    "PENDING": "Pending",
}


def wait_until(callable_fn, poll=5):
31+ """Placeholder docstring"""
32+ elapsed_time = 0
33+ result = None
34+ while result is None :
35+ try :
36+ elapsed_time += poll
37+ time .sleep (poll )
38+ result = callable_fn ()
39+ except botocore .exceptions .ClientError as err :
40+ # For initial 5 mins we accept/pass AccessDeniedException.
41+ # The reason is to await tag propagation to avoid false AccessDenied claims for an
42+ # access policy based on resource tags, The caveat here is for true AccessDenied
43+ # cases the routine will fail after 5 mins
44+ if err .response ["Error" ]["Code" ] == "AccessDeniedException" and elapsed_time <= 300 :
45+ logger .warning (
46+ "Received AccessDeniedException. This could mean the IAM role does not "
47+ "have the resource permissions, in which case please add resource access "
48+ "and retry. For cases where the role has tag based resource policy, "
49+ "continuing to wait for tag propagation.."
50+ )
51+ continue
52+ raise err
53+ return result
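# Illustrative usage sketch (assumes a boto3 SageMaker client named ``sagemaker_client`` and a
# training job name ``job_name``): AccessDeniedException is retried for up to five minutes to
# ride out tag propagation; any other ClientError is raised immediately.
#
#   description = wait_until(
#       lambda: sagemaker_client.describe_training_job(TrainingJobName=job_name)
#   )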


def get_initial_job_state(description, status_key, wait):
    """Return the initial ``LogState`` for a job given its describe response and status key."""
    status = description[status_key]
    job_already_completed = status in ("Completed", "Failed", "Stopped")
    return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE


def logs_init(boto_session, description, job):
    """Prepare everything needed to tail CloudWatch logs for a job.

    Returns the instance count, the (initially empty) stream names and positions, a CloudWatch
    Logs client, the log group name, the ``dot`` progress flag, and a ``ColorWrap`` printer.
    """
    if job == "Training":
        if "InstanceGroups" in description["ResourceConfig"]:
            instance_count = 0
            for instance_group in description["ResourceConfig"]["InstanceGroups"]:
                instance_count += instance_group["InstanceCount"]
        else:
            instance_count = description["ResourceConfig"]["InstanceCount"]
    elif job == "Transform":
        instance_count = description["TransformResources"]["InstanceCount"]
    elif job == "Processing":
        instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
    elif job == "AutoML":
        instance_count = 0

    stream_names = []  # The list of log streams
    positions = {}  # The current position in each stream, map of stream name -> position

    # Increase retries allowed (from default of 4), as we don't want waiting for a training job
    # to be interrupted by a transient exception.
    config = botocore.config.Config(retries={"max_attempts": 15})
    client = boto_session.client("logs", config=config)
    log_group = "/aws/sagemaker/" + job + "Jobs"

    dot = False

    color_wrap = ColorWrap()

    return instance_count, stream_names, positions, client, log_group, dot, color_wrap


def flush_log_streams(
    stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
):
98+ """Placeholder docstring"""
99+ if len (stream_names ) < instance_count :
100+ # Log streams are created whenever a container starts writing to stdout/err, so this list
101+ # may be dynamic until we have a stream for every instance.
102+ try :
103+ streams = client .describe_log_streams (
104+ logGroupName = log_group ,
105+ logStreamNamePrefix = job_name + "/" ,
106+ orderBy = "LogStreamName" ,
107+ limit = min (instance_count , 50 ),
108+ )
109+ stream_names = [s ["logStreamName" ] for s in streams ["logStreams" ]]
110+
111+ while "nextToken" in streams :
112+ streams = client .describe_log_streams (
113+ logGroupName = log_group ,
114+ logStreamNamePrefix = job_name + "/" ,
115+ orderBy = "LogStreamName" ,
116+ limit = 50 ,
117+ )
118+
119+ stream_names .extend ([s ["logStreamName" ] for s in streams ["logStreams" ]])
120+
121+ positions .update (
122+ [
123+ (s , Position (timestamp = 0 , skip = 0 ))
124+ for s in stream_names
125+ if s not in positions
126+ ]
127+ )
128+ except ClientError as e :
129+ # On the very first training job run on an account, there's no log group until
130+ # the container starts logging, so ignore any errors thrown about that
131+ err = e .response .get ("Error" , {})
132+ if err .get ("Code" , None ) != "ResourceNotFoundException" :
133+ raise
134+
135+ if len (stream_names ) > 0 :
136+ if dot :
137+ print ("" )
138+ dot = False
139+ for idx , event in multi_stream_iter (
140+ client , log_group , stream_names , positions
141+ ):
142+ color_wrap (idx , event ["message" ])
143+ ts , count = positions [stream_names [idx ]]
144+ if event ["timestamp" ] == ts :
145+ positions [stream_names [idx ]] = Position (
146+ timestamp = ts , skip = count + 1
147+ )
148+ else :
149+ positions [stream_names [idx ]] = Position (
150+ timestamp = event ["timestamp" ], skip = 1
151+ )
152+ else :
153+ dot = True
154+ print ("." , end = "" )
155+ sys .stdout .flush ()


def rule_statuses_changed(current_statuses, last_statuses):
    """Checks the rule evaluation statuses for SageMaker Debugger and Profiler rules."""
    if not last_statuses:
        return True

    for current, last in zip(current_statuses, last_statuses):
        if (current["RuleConfigurationName"] == last["RuleConfigurationName"]) and (
            current["RuleEvaluationStatus"] != last["RuleEvaluationStatus"]
        ):
            return True

    return False
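# Illustrative example (the status values below are assumptions, not taken from a real response):
# the helper reports a change when a rule with the same RuleConfigurationName moves to a new
# RuleEvaluationStatus between two polls, and always reports a change on the first poll.
#
#   previous = [{"RuleConfigurationName": "LossNotDecreasing", "RuleEvaluationStatus": "InProgress"}]
#   current = [{"RuleConfigurationName": "LossNotDecreasing", "RuleEvaluationStatus": "IssuesFound"}]
#   rule_statuses_changed(current, previous)  # True (status changed)
#   rule_statuses_changed(current, current)   # False (nothing changed)
#   rule_statuses_changed(current, None)      # True (first poll)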


def check_job_status(job, desc, status_key_name):
    """Check whether the job completed successfully.

    If not, construct and raise an ``exceptions.UnexpectedStatusException``.

    Args:
        job (str): The name of the job to check.
        desc (dict[str, str]): The result of ``describe_training_job()``.
        status_key_name (str): Status key name to check for.

    Raises:
        exceptions.CapacityError: If the training job fails with CapacityError.
        exceptions.UnexpectedStatusException: If the training job fails.
    """
    status = desc[status_key_name]
    # If the status is capital case, then convert it to Camel case
    status = STATUS_CODE_TABLE.get(status, status)

    if status == "Stopped":
        logger.warning(
            "Job ended with status 'Stopped' rather than 'Completed'. "
            "This could mean the job timed out or stopped early for some other reason: "
            "Consider checking whether it completed as you expect."
        )
    elif status != "Completed":
        reason = desc.get("FailureReason", "(No reason provided)")
        job_type = status_key_name.replace("JobStatus", " job")
        troubleshooting = (
            "https://docs.aws.amazon.com/sagemaker/latest/dg/"
            "sagemaker-python-sdk-troubleshooting.html"
        )
        message = (
            "Error for {job_type} {job_name}: {status}. Reason: {reason}. "
            "Check troubleshooting guide for common errors: {troubleshooting}"
        ).format(
            job_type=job_type,
            job_name=job,
            status=status,
            reason=reason,
            troubleshooting=troubleshooting,
        )
        if "CapacityError" in str(reason):
            raise exceptions.CapacityError(
                message=message,
                allowed_statuses=["Completed", "Stopped"],
                actual_status=status,
            )
        raise exceptions.UnexpectedStatusException(
            message=message,
            allowed_statuses=["Completed", "Stopped"],
            actual_status=status,
        )


def logs_for_job(
    model_trainer, wait=False, poll=10, log_type="All", timeout=None
):
229+ """Display logs for a given training job, optionally tailing them until job is complete.
230+
231+ If the output is a tty or a Jupyter cell, it will be color-coded
232+ based on which instance the log entry is from.
233+
234+ Args:
235+ model_trainer (sagemaker.train.ModelTrainer): The ModelTrainer used for the
236+ training job
237+ wait (bool): Whether to keep looking for new log entries until the job completes
238+ (default: False).
239+ poll (int): The interval in seconds between polling for new log entries and job
240+ completion (default: 5).
241+ log_type ([str]): A list of strings specifying which logs to print. Acceptable
242+ strings are "All", "None", "Training", or "Rules". To maintain backwards
243+ compatibility, boolean values are also accepted and converted to strings.
244+ timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by
245+ default.
246+ Returns:
247+ Last call to sagemaker DescribeTrainingJob
248+ Raises:
249+ exceptions.CapacityError: If the training job fails with CapacityError.
250+ exceptions.UnexpectedStatusException: If waiting and the training job fails.
251+ """
252+ sagemaker_session = model_trainer .sagemaker_session
253+ job_name = model_trainer ._latest_training_job .training_job_name
254+
255+ sagemaker_client = sagemaker_session .sagemaker_client
256+ request_end_time = time .time () + timeout if timeout else None
257+ description = wait_until (
258+ lambda : sagemaker_client .describe_training_job (TrainingJobName = job_name )
259+ )
260+ print (secondary_training_status_message (description , None ), end = "" )
261+
262+ instance_count , stream_names , positions , client , log_group , dot , color_wrap = logs_init (
263+ sagemaker_session .boto_session , description , job = "Training"
264+ )
265+
266+ state = get_initial_job_state (description , "TrainingJobStatus" , wait )
267+
268+ # The loop below implements a state machine that alternates between checking the job status
269+ # and reading whatever is available in the logs at this point. Note, that if we were
270+ # called with wait == False, we never check the job status.
271+ #
272+ # If wait == TRUE and job is not completed, the initial state is TAILING
273+ # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
274+ # complete).
275+ #
276+ # The state table:
277+ #
278+ # STATE ACTIONS CONDITION NEW STATE
279+ # ---------------- ---------------- ----------------- ----------------
280+ # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
281+ # Else TAILING
282+ # JOB_COMPLETE Read logs, Pause Any COMPLETE
283+ # COMPLETE Read logs, Exit N/A
284+ #
285+ # Notes:
286+ # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
287+ # Cloudwatch after the job was marked complete.
288+ last_describe_job_call = time .time ()
289+ last_description = description
290+ last_debug_rule_statuses = None
291+ last_profiler_rule_statuses = None
292+
293+ while True :
294+ flush_log_streams (
295+ stream_names ,
296+ instance_count ,
297+ client ,
298+ log_group ,
299+ job_name ,
300+ positions ,
301+ dot ,
302+ color_wrap ,
303+ )
304+ if timeout and time .time () > request_end_time :
305+ print ("Timeout Exceeded. {} seconds elapsed." .format (timeout ))
306+ break
307+
308+ if state == LogState .COMPLETE :
309+ break
310+
311+ time .sleep (poll )
312+
313+ if state == LogState .JOB_COMPLETE :
314+ state = LogState .COMPLETE
315+ elif time .time () - last_describe_job_call >= 30 :
316+ description = sagemaker_client .describe_training_job (TrainingJobName = job_name )
317+ last_describe_job_call = time .time ()
318+
319+ if secondary_training_status_changed (description , last_description ):
320+ print ()
321+ print (secondary_training_status_message (description , last_description ), end = "" )
322+ last_description = description
323+
324+ status = description ["TrainingJobStatus" ]
325+
326+ if status in ("Completed" , "Failed" , "Stopped" ):
327+ print ()
328+ state = LogState .JOB_COMPLETE
329+
330+ # Print prettified logs related to the status of SageMaker Debugger rules.
331+ debug_rule_statuses = description .get ("DebugRuleEvaluationStatuses" , {})
332+ if (
333+ debug_rule_statuses
334+ and rule_statuses_changed (debug_rule_statuses , last_debug_rule_statuses )
335+ and (log_type in {"All" , "Rules" })
336+ ):
337+ for status in debug_rule_statuses :
338+ rule_log = (
339+ f"{ status ['RuleConfigurationName' ]} : { status ['RuleEvaluationStatus' ]} "
340+ )
341+ print (rule_log )
342+
343+ last_debug_rule_statuses = debug_rule_statuses
344+
345+ # Print prettified logs related to the status of SageMaker Profiler rules.
346+ profiler_rule_statuses = description .get ("ProfilerRuleEvaluationStatuses" , {})
347+ if (
348+ profiler_rule_statuses
349+ and rule_statuses_changed (profiler_rule_statuses , last_profiler_rule_statuses )
350+ and (log_type in {"All" , "Rules" })
351+ ):
352+ for status in profiler_rule_statuses :
353+ rule_log = (
354+ f"{ status ['RuleConfigurationName' ]} : { status ['RuleEvaluationStatus' ]} "
355+ )
356+ print (rule_log )
357+
358+ last_profiler_rule_statuses = profiler_rule_statuses
359+
360+ if wait :
361+ check_job_status (job_name , description , "TrainingJobStatus" )
362+ if dot :
363+ print ()
364+ # Customers are not billed for hardware provisioning, so billable time is less than
365+ # total time
366+ training_time = description .get ("TrainingTimeInSeconds" )
367+ billable_time = description .get ("BillableTimeInSeconds" )
368+ if training_time is not None :
369+ print ("Training seconds:" , training_time * instance_count )
370+ if billable_time is not None :
371+ print ("Billable seconds:" , billable_time * instance_count )
372+ if description .get ("EnableManagedSpotTraining" ):
373+ saving = (1 - float (billable_time ) / training_time ) * 100
374+ print ("Managed Spot Training savings: {:.1f}%" .format (saving ))
375+ return last_description
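# Usage sketch (illustrative, kept as comments so nothing runs on import). It assumes a
# ``ModelTrainer`` from ``sagemaker.train`` (per the docstring above) that exposes a
# ``train()`` method and has already started a training job, so that
# ``_latest_training_job.training_job_name`` is populated; those names beyond what this
# module itself uses are assumptions.
#
#   from sagemaker.train import ModelTrainer
#
#   model_trainer = ModelTrainer(...)        # configured with image, role, inputs, etc.
#   model_trainer.train(wait=False)          # start the job without blocking
#   logs_for_job(model_trainer, wait=True)   # tail CloudWatch logs until the job finishes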