1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313
14+ """This module contains code related to forwarding SageMaker TrainingJob Metrics to MLflow."""
15+
16+ from __future__ import absolute_import
1417
15- import boto3
1618import os
19+ import platform
20+ import re
21+ from typing import Set , Tuple , List , Dict , Generator
22+ import boto3
1723import mlflow
1824from mlflow import MlflowClient
1925from mlflow .entities import Metric , Param , RunTag
2026
2127from packaging import version
22- import platform
23- import re
24-
25- from typing import Set , Tuple , List , Dict
2628
2729
2830def encode (name : str , existing_names : Set [str ]) -> str :
29- """
30- Encode a string to comply with MLflow naming restrictions and ensure uniqueness.
31+ """Encode a string to comply with MLflow naming restrictions and ensure uniqueness.
3132
3233 Args:
3334 name (str): The original string to be encoded.
@@ -55,7 +56,8 @@ def encode_char(match):
5556
5657 if base_name in existing_names :
5758 suffix = 1
58- # Edge case where even with suffix space there is a collision we will override one of the keys.
59+ # Edge case: if there is still a collision even after adding a suffix,
60+ # one of the keys will be overridden.
5961 while f"{ base_name } _{ suffix } " in existing_names :
6062 suffix += 1
6163 encoded = f"{ base_name } _{ suffix } "
@@ -68,9 +70,7 @@ def encode_char(match):
6870
6971
7072def decode (encoded_metric_name : str ) -> str :
71-
72- # TODO: Utilize the stored name mappings to get the original key mappings without having to decode.
73- """Decodes an encoded metric name by replacing hexadecimal representations with their corresponding characters.
73+ """Decodes an encoded metric name by replacing hexadecimal codes with ASCII characters.
7474
7575 This function reverses the encoding process by converting hexadecimal codes
7676 back to their original characters. It looks for patterns of the form "_XX_"
@@ -100,8 +100,7 @@ def replace_code(match):
100100
101101
102102def get_training_job_details (job_arn : str ) -> dict :
103- """
104- Retrieve details of a SageMaker training job.
103+ """Retrieve details of a SageMaker training job.
105104
106105 Args:
107106 job_arn (str): The ARN of the SageMaker training job.
@@ -118,8 +117,7 @@ def get_training_job_details(job_arn: str) -> dict:
118117
119118
120119def create_metric_queries (job_arn : str , metric_definitions : list ) -> list :
121- """
122- Create metric queries for SageMaker metrics.
120+ """Create metric queries for SageMaker metrics.
123121
124122 Args:
125123 job_arn (str): The ARN of the SageMaker training job.
@@ -142,8 +140,7 @@ def create_metric_queries(job_arn: str, metric_definitions: list) -> list:
142140
143141
144142def get_metric_data (metric_queries : list ) -> dict :
145- """
146- Retrieve metric data from SageMaker.
143+ """Retrieve metric data from SageMaker.
147144
148145 Args:
149146 metric_queries (list): A list of metric queries.
@@ -162,8 +159,7 @@ def get_metric_data(metric_queries: list) -> dict:
162159def prepare_mlflow_metrics (
163160 metric_queries : list , metric_results : list
164161) -> Tuple [List [Metric ], Dict [str , str ]]:
165- """
166- Prepare metrics for MLflow logging, encoding metric names if necessary.
162+ """Prepare metrics for MLflow logging, encoding metric names if necessary.
167163
168164 Args:
169165 metric_queries (list): The original metric queries sent to SageMaker.
@@ -184,29 +180,26 @@ def prepare_mlflow_metrics(
184180 encoded_name = encode (metric_name , existing_names )
185181 metric_name_mapping [encoded_name ] = metric_name
186182
187- mlflow_metrics .extend (
188- [
189- Metric (key = encoded_name , value = value , timestamp = timestamp , step = step )
190- for step , (timestamp , value ) in enumerate (
191- zip (result ["XAxisValues" ], result ["MetricValues" ])
192- )
193- ]
194- )
183+ for step , (timestamp , value ) in enumerate (
184+ zip (result ["XAxisValues" ], result ["MetricValues" ])
185+ ):
186+ metric = Metric (key = encoded_name , value = value , timestamp = timestamp , step = step )
187+ mlflow_metrics .append (metric )
195188
196189 return mlflow_metrics , metric_name_mapping
197190
198191
199192def prepare_mlflow_params (hyperparameters : Dict [str , str ]) -> Tuple [List [Param ], Dict [str , str ]]:
200- """
201- Prepare hyperparameters for MLflow logging, encoding parameter names if necessary.
193+ """Prepare hyperparameters for MLflow logging, encoding parameter names if necessary.
202194
203195 Args:
204196 hyperparameters (Dict[str, str]): The hyperparameters from the SageMaker job.
205197
206198 Returns:
207199 Tuple[List[Param], Dict[str, str]]:
208200 - A list of Param objects with encoded names (if necessary)
209- - A mapping of encoded to original names for hyperparameters (only for encoded parameters)
201+ - A mapping of encoded to original names for
202+ hyperparameters (only for encoded parameters)
210203 """
211204 mlflow_params = []
212205 param_name_mapping = {}
@@ -220,9 +213,8 @@ def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param],
220213 return mlflow_params , param_name_mapping
221214
222215
223- def batch_items (items : list , batch_size : int ) -> list :
224- """
225- Yield successive batch_size chunks from items.
216+ def batch_items (items : list , batch_size : int ) -> Generator :
217+ """Yield successive batch_size chunks from items.
226218
227219 Args:
228220 items (list): The list of items to be batched.
@@ -236,8 +228,7 @@ def batch_items(items: list, batch_size: int) -> list:
236228
237229
238230def log_to_mlflow (metrics : list , params : list , tags : dict ) -> None :
239- """
240- Log metrics, parameters, and tags to MLflow.
231+ """Log metrics, parameters, and tags to MLflow.
241232
242233 Args:
243234 metrics (list): List of metrics to log.
@@ -278,8 +269,7 @@ def log_to_mlflow(metrics: list, params: list, tags: dict) -> None:
278269
279270
280271def log_sagemaker_job_to_mlflow (training_job_arn : str ) -> None :
281- """
282- Retrieve SageMaker metrics and hyperparameters and log them to MLflow.
272+ """Retrieve SageMaker metrics and hyperparameters and log them to MLflow.
283273
284274 Args:
285275 training_job_arn (str): The ARN of the SageMaker training job.
0 commit comments