2020from six import with_metaclass
2121
2222from sagemaker .session import Session
23- from sagemaker .utils import DeferredError
23+ from sagemaker .utils import DeferredError , extract_name_from_job_arn
2424
2525try :
2626 import pandas as pd
@@ -201,12 +201,13 @@ class TrainingJobAnalytics(AnalyticsMetricsBase):
201201
202202 CLOUDWATCH_NAMESPACE = '/aws/sagemaker/HyperParameterTuningJobs'
203203
204- def __init__ (self , training_job_name , metric_names , sagemaker_session = None ):
204+ def __init__ (self , training_job_name , metric_names = None , sagemaker_session = None ):
205205 """Initialize a ``TrainingJobAnalytics`` instance.
206206
207207 Args:
208208 training_job_name (str): name of the TrainingJob to analyze.
209- metric_names (list): string names of all the metrics to collect for this training job
209+ metric_names (list, optional): string names of all the metrics to collect for this training job.
210+ If not specified, then it will use all metric names configured for this job.
210211 sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
211212 Amazon SageMaker APIs and any other AWS services needed. If not specified, one is specified
212213 using the default AWS configuration chain.
@@ -215,7 +216,10 @@ def __init__(self, training_job_name, metric_names, sagemaker_session=None):
215216 self ._sage_client = sagemaker_session .sagemaker_client
216217 self ._cloudwatch = sagemaker_session .boto_session .client ('cloudwatch' )
217218 self ._training_job_name = training_job_name
218- self ._metric_names = metric_names
219+ if metric_names :
220+ self ._metric_names = metric_names
221+ else :
222+ self ._metric_names = self ._metric_names_for_training_job ()
219223 self .clear_cache ()
220224
221225 @property
@@ -297,3 +301,22 @@ def _add_single_metric(self, timestamp, metric_name, value):
297301 self ._data ['timestamp' ].append (timestamp )
298302 self ._data ['metric_name' ].append (metric_name )
299303 self ._data ['value' ].append (value )
304+
305+ def _metric_names_for_training_job (self ):
306+ """Helper method to discover the metrics defined for a training job.
307+ """
308+ # First look up the tuning job
309+ training_description = self ._sage_client .describe_training_job (TrainingJobName = self ._training_job_name )
310+ tuning_job_arn = training_description .get ('TuningJobArn' , None )
311+ if not tuning_job_arn :
312+ raise ValueError (
313+ "No metrics available. Training Job Analytics only available through Hyperparameter Tuning Jobs"
314+ )
315+ tuning_job_name = extract_name_from_job_arn (tuning_job_arn )
316+ tuning_job_description = self ._sage_client .describe_hyper_parameter_tuning_job (
317+ HyperParameterTuningJobName = tuning_job_name
318+ )
319+ training_job_definition = tuning_job_description ['TrainingJobDefinition' ]
320+ metric_definitions = training_job_definition ['AlgorithmSpecification' ]['MetricDefinitions' ]
321+ metric_names = [md ['Name' ] for md in metric_definitions ]
322+ return metric_names
0 commit comments