diff --git a/pyproject.toml b/pyproject.toml
index 1c60e786e2..80dec58b99 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,6 +48,7 @@ dependencies = [
     "PyYAML~=6.0",
     "requests",
     "sagemaker-core>=1.0.0,<2.0.0",
+    "sagemaker-mlflow",
     "schema",
     "smdebug_rulesconfig==1.0.1",
     "tblib>=1.7.0,<4",
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index 6f02fde8e8..2aba19b112 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -107,6 +107,8 @@
 from sagemaker.workflow.parameters import ParameterString
 from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
 
+from sagemaker.mlflow.forward_sagemaker_metrics import log_sagemaker_job_to_mlflow
+
 logger = logging.getLogger(__name__)
 
 
@@ -1366,8 +1368,14 @@ def fit(
         experiment_config = check_and_get_run_experiment_config(experiment_config)
         self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
         self.jobs.append(self.latest_training_job)
+        forward_to_mlflow_tracking_server = False
+        if os.environ.get("MLFLOW_TRACKING_URI") and self.enable_network_isolation():
+            wait = True
+            forward_to_mlflow_tracking_server = True
         if wait:
             self.latest_training_job.wait(logs=logs)
+        if forward_to_mlflow_tracking_server:
+            log_sagemaker_job_to_mlflow(self.latest_training_job.name)
 
     def _compilation_job_name(self):
         """Placeholder docstring"""
diff --git a/src/sagemaker/mlflow/forward_sagemaker_metrics.py b/src/sagemaker/mlflow/forward_sagemaker_metrics.py
new file mode 100644
index 0000000000..a025241d6a
--- /dev/null
+++ b/src/sagemaker/mlflow/forward_sagemaker_metrics.py
@@ -0,0 +1,311 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+"""This module contains code related to forwarding SageMaker TrainingJob Metrics to MLflow."""
+
+from __future__ import absolute_import
+
+import os
+import platform
+import re
+from typing import Set, Tuple, List, Dict, Generator
+
+import boto3
+import mlflow
+from mlflow import MlflowClient
+from mlflow.entities import Metric, Param, RunTag
+from packaging import version
+
+
+def encode(name: str, existing_names: Set[str]) -> str:
+    """Encode a string to comply with MLflow naming restrictions and ensure uniqueness.
+
+    Args:
+        name (str): The original string to be encoded.
+        existing_names (Set[str]): Set of existing encoded names to avoid collisions.
+
+    Returns:
+        str: The encoded string if changes were necessary, otherwise the original string.
+    """
+
+    def encode_char(match):
+        return f"_{ord(match.group(0)):02x}_"
+
+    # Colons are only allowed on Unix-like systems running MLflow 2.16.0 or greater
+    is_unix = platform.system() != "Windows"
+    mlflow_version = version.parse(mlflow.__version__)
+    allow_colon = is_unix and mlflow_version >= version.parse("2.16.0")
+
+    if allow_colon:
+        pattern = r"[^\w\-./:\s]"
+    else:
+        pattern = r"[^\w\-./\s]"
+
+    encoded = re.sub(pattern, encode_char, name)
+    base_name = encoded[:240]  # Leave room for a potential suffix to accommodate duplicates
+
+    if base_name in existing_names:
+        suffix = 1
+        # Edge case: even with room reserved for a suffix there can still be
+        # a collision after truncation, in which case one of the keys is overwritten.
+        while f"{base_name}_{suffix}" in existing_names:
+            suffix += 1
+        encoded = f"{base_name}_{suffix}"
+
+    # Max length is 250 for MLflow metric/param names
+    encoded = encoded[:250]
+
+    existing_names.add(encoded)
+    return encoded
+
+
+def decode(encoded_metric_name: str) -> str:
+    """Decode an encoded metric name by replacing hexadecimal codes with ASCII characters.
+
+    This function reverses the encoding process by converting hexadecimal codes
+    back to their original characters. It looks for patterns of the form "_XX_",
+    where XX is a two-digit hexadecimal code, and replaces them with the
+    corresponding ASCII character.
+
+    Args:
+        encoded_metric_name (str): The encoded metric name to be decoded.
+
+    Returns:
+        str: The decoded metric name with hexadecimal codes replaced by their
+            corresponding characters.
+
+    Example:
+        >>> decode("loss_3a_val")
+        "loss:val"
+    """
+
+    def replace_code(match):
+        code = match.group(1)
+        return chr(int(code, 16))
+
+    # Replace encoded characters
+    decoded = re.sub(r"_([0-9a-f]{2})_", replace_code, encoded_metric_name)
+
+    return decoded
+
+
+def get_training_job_details(job_arn: str) -> dict:
+    """Retrieve details of a SageMaker training job.
+
+    Args:
+        job_arn (str): The ARN of the SageMaker training job.
+
+    Returns:
+        dict: A dictionary containing the details of the training job.
+
+    Raises:
+        boto3.exceptions.Boto3Error: If there's an issue with the AWS API call.
+    """
+    sagemaker_client = boto3.client("sagemaker")
+    job_name = job_arn.split("/")[-1]
+    return sagemaker_client.describe_training_job(TrainingJobName=job_name)
+
+
+def create_metric_queries(job_arn: str, metric_definitions: list) -> list:
+    """Create metric queries for SageMaker metrics.
+
+    Args:
+        job_arn (str): The ARN of the SageMaker training job.
+        metric_definitions (list): List of metric definitions from the training job.
+
+    Returns:
+        list: A list of dictionaries, each representing a metric query.
+    """
+    metric_queries = []
+    for metric in metric_definitions:
+        query = {
+            "MetricName": metric["Name"],
+            "XAxisType": "Timestamp",
+            "MetricStat": "Avg",
+            "Period": "OneMinute",
+            "ResourceArn": job_arn,
+        }
+        metric_queries.append(query)
+    return metric_queries
+
+
+def get_metric_data(metric_queries: list) -> dict:
+    """Retrieve metric data from SageMaker.
+
+    Args:
+        metric_queries (list): A list of metric queries.
+
+    Returns:
+        dict: A dictionary containing the metric data results.
+
+    Raises:
+        boto3.exceptions.Boto3Error: If there's an issue with the AWS API call.
+    """
+    sagemaker_metrics_client = boto3.client("sagemaker-metrics")
+    metric_data = sagemaker_metrics_client.batch_get_metrics(MetricQueries=metric_queries)
+    return metric_data
+
+
+def prepare_mlflow_metrics(
+    metric_queries: list, metric_results: list
+) -> Tuple[List[Metric], Dict[str, str]]:
+    """Prepare metrics for MLflow logging, encoding metric names if necessary.
+
+    Args:
+        metric_queries (list): The original metric queries sent to SageMaker.
+        metric_results (list): The metric results from SageMaker batch_get_metrics.
+
+    Returns:
+        Tuple[List[Metric], Dict[str, str]]:
+            - A list of Metric objects with encoded names (if necessary)
+            - A mapping of encoded metric names back to the original names
+    """
+    mlflow_metrics = []
+    metric_name_mapping = {}
+    existing_names = set()
+
+    for query, result in zip(metric_queries, metric_results):
+        if result["Status"] == "Complete":
+            metric_name = query["MetricName"]
+            encoded_name = encode(metric_name, existing_names)
+            metric_name_mapping[encoded_name] = metric_name
+
+            for step, (timestamp, value) in enumerate(
+                zip(result["XAxisValues"], result["MetricValues"])
+            ):
+                metric = Metric(key=encoded_name, value=value, timestamp=timestamp, step=step)
+                mlflow_metrics.append(metric)
+
+    return mlflow_metrics, metric_name_mapping
+
+
+def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param], Dict[str, str]]:
+    """Prepare hyperparameters for MLflow logging, encoding parameter names if necessary.
+
+    Args:
+        hyperparameters (Dict[str, str]): The hyperparameters from the SageMaker job.
+
+    Returns:
+        Tuple[List[Param], Dict[str, str]]:
+            - A list of Param objects with encoded names (if necessary)
+            - A mapping of encoded parameter names back to the original names
+    """
+    mlflow_params = []
+    param_name_mapping = {}
+    existing_names = set()
+
+    for key, value in hyperparameters.items():
+        encoded_key = encode(key, existing_names)
+        param_name_mapping[encoded_key] = key
+        mlflow_params.append(Param(encoded_key, str(value)))
+
+    return mlflow_params, param_name_mapping
+
+
+def batch_items(items: list, batch_size: int) -> Generator:
+    """Yield successive batch_size chunks from items.
+
+    Args:
+        items (list): The list of items to be batched.
+        batch_size (int): The size of each batch.
+
+    Yields:
+        list: A batch of items.
+    """
+    for i in range(0, len(items), batch_size):
+        yield items[i : i + batch_size]
+
+
+def log_to_mlflow(metrics: list, params: list, tags: dict) -> None:
+    """Log metrics, parameters, and tags to MLflow.
+
+    Args:
+        metrics (list): List of metrics to log.
+        params (list): List of parameters to log.
+        tags (dict): Dictionary of tags to set.
+
+    Raises:
+        mlflow.exceptions.MlflowException: If there's an issue with MLflow logging.
+    """
+    client = MlflowClient()
+
+    experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME")
+    if experiment_name is None or experiment_name.strip() == "":
+        experiment_name = "Default"
+        print("MLFLOW_EXPERIMENT_NAME not set. Using Default")
+
+    experiment = client.get_experiment_by_name(experiment_name)
+    if experiment is None:
+        experiment_id = client.create_experiment(experiment_name)
+    else:
+        experiment_id = experiment.experiment_id
+
+    run = client.create_run(experiment_id)
+
+    for metric_batch in batch_items(metrics, 1000):
+        client.log_batch(
+            run.info.run_id,
+            metrics=metric_batch,
+        )
+    for param_batch in batch_items(params, 1000):
+        client.log_batch(run.info.run_id, params=param_batch)
+
+    tag_items = list(tags.items())
+    for tag_batch in batch_items(tag_items, 1000):
+        tag_objects = [RunTag(key, str(value)) for key, value in tag_batch]
+        client.log_batch(run.info.run_id, tags=tag_objects)
+    client.set_terminated(run.info.run_id)
+
+
+def log_sagemaker_job_to_mlflow(training_job_arn: str) -> None:
+    """Retrieve SageMaker metrics and hyperparameters and log them to MLflow.
+
+    Args:
+        training_job_arn (str): The ARN of the SageMaker training job.
+
+    Raises:
+        Exception: If there's any error during the process.
+    """
+    # Get training job details
+    mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
+    job_details = get_training_job_details(training_job_arn)
+
+    # Extract hyperparameters and metric definitions
+    hyperparameters = job_details["HyperParameters"]
+    metric_definitions = job_details["AlgorithmSpecification"]["MetricDefinitions"]
+
+    # Build the metric queries and fetch the metric data
+    metric_queries = create_metric_queries(job_details["TrainingJobArn"], metric_definitions)
+    metric_data = get_metric_data(metric_queries)
+
+    # Prepare metrics for MLflow, keeping a mapping of encoded to original metric names
+    mlflow_metrics, metric_name_mapping = prepare_mlflow_metrics(
+        metric_queries, metric_data["MetricQueryResults"]
+    )
+
+    # Prepare hyperparameters for MLflow, keeping a mapping of encoded to original names
+    mlflow_params, param_name_mapping = prepare_mlflow_params(hyperparameters)
+
+    mlflow_tags = {
+        "training_job_arn": training_job_arn,
+        "metric_name_mapping": str(metric_name_mapping),
+        "param_name_mapping": str(param_name_mapping),
+    }
+
+    # Log to MLflow
+    log_to_mlflow(mlflow_metrics, mlflow_params, mlflow_tags)
+    print(f"Logged {len(mlflow_metrics)} metric datapoints to MLflow")
+    print(f"Logged {len(mlflow_params)} hyperparameters to MLflow")
diff --git a/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py b/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py
new file mode 100644
index 0000000000..4b53c93ad4
--- /dev/null
+++ b/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py
@@ -0,0 +1,272 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+from __future__ import absolute_import
+
+import json
+from unittest.mock import patch, MagicMock, Mock
+
+import pytest
+import requests
+from mlflow.entities import Metric, Param
+
+from sagemaker.mlflow.forward_sagemaker_metrics import (
+    encode,
+    log_sagemaker_job_to_mlflow,
+    decode,
+    prepare_mlflow_metrics,
+    prepare_mlflow_params,
+    batch_items,
+    create_metric_queries,
+    get_metric_data,
+    log_to_mlflow,
+    get_training_job_details,
+)
+
+
+@pytest.fixture
+def mock_boto3_client():
+    with patch("boto3.client") as mock_client:
+        yield mock_client
+
+
+@pytest.fixture
+def mock_mlflow_client():
+    with patch("mlflow.MlflowClient") as mock_client:
+        yield mock_client
+
+
+def test_encode():
+    existing_names = set()
+    assert encode("test-name", existing_names) == "test-name"
+    assert encode("test:name", existing_names) == "test_3a_name"
+    assert encode("test-name", existing_names) == "test-name_1"
+
+
+def test_encode_colon_allowed():
+    # Test case where colon is allowed (Unix-like system and MLflow >= 2.16.0)
+    with patch("platform.system") as mock_system, patch("mlflow.__version__", new="2.16.0"):
+        mock_system.return_value = "Darwin"  # MacOS
+        existing_names = set()
+
+        assert encode("test:name", existing_names) == "test:name"
+        assert encode("test/name", existing_names) == "test/name"
+        assert encode("test name", existing_names) == "test name"
+        assert encode("test@name", existing_names) == "test_40_name"
+
+        # Test a name at the 250-character limit
+        long_name = "a" * 250
+        encoded_long_name = encode(long_name, existing_names)
+        assert len(encoded_long_name) == 250
+        assert encoded_long_name == "a" * 250
+
+        # Test suffix addition for duplicate names
+        assert encode("duplicate", existing_names) == "duplicate"
+        assert encode("duplicate", existing_names) == "duplicate_1"
+        assert encode("duplicate", existing_names) == "duplicate_2"
+
+
+def test_decode():
+    assert decode("test_3a_name") == "test:name"
+    assert decode("normal_name") == "normal_name"
+
+
+def test_get_training_job_details(mock_boto3_client):
+    mock_sagemaker = MagicMock()
+    mock_boto3_client.return_value = mock_sagemaker
+    mock_sagemaker.describe_training_job.return_value = {"JobName": "test-job"}
+
+    result = get_training_job_details(
+        "arn:aws:sagemaker:us-west-2:123456789012:training-job/test-job"
+    )
+    assert result == {"JobName": "test-job"}
+    mock_sagemaker.describe_training_job.assert_called_once_with(TrainingJobName="test-job")
+
+
+def test_create_metric_queries():
+    job_arn = "arn:aws:sagemaker:us-west-2:123456789012:training-job/test-job"
+    metric_definitions = [{"Name": "loss"}, {"Name": "accuracy"}]
+    result = create_metric_queries(job_arn, metric_definitions)
+    assert len(result) == 2
+    assert result[0]["MetricName"] == "loss"
+    assert result[1]["MetricName"] == "accuracy"
+
+
+def test_get_metric_data(mock_boto3_client):
+    mock_metrics = MagicMock()
+    mock_boto3_client.return_value = mock_metrics
+    mock_metrics.batch_get_metrics.return_value = {"MetricResults": []}
+
+    metric_queries = [{"MetricName": "loss"}]
+    result = get_metric_data(metric_queries)
+    assert result == {"MetricResults": []}
+    mock_metrics.batch_get_metrics.assert_called_once_with(MetricQueries=metric_queries)
+
+
+def test_prepare_mlflow_metrics():
+    metric_queries = [{"MetricName": "loss"}, {"MetricName": "accuracy!"}]
+    metric_results = [
+        {"Status": "Complete", "XAxisValues": [1, 2], "MetricValues": [0.1, 0.2]},
+        {"Status": "Complete", "XAxisValues": [1, 2], "MetricValues": [0.8, 0.9]},
+    ]
+    expected_encoded = {"loss": "loss", "accuracy_21_": "accuracy!"}
+
+    metrics, mapping = prepare_mlflow_metrics(metric_queries, metric_results)
+
+    assert len(metrics) == sum(len(result["MetricValues"]) for result in metric_results)
+
+    expected_metrics = [
+        ("loss", 0.1, 1, 0),
+        ("loss", 0.2, 2, 1),
+        ("accuracy_21_", 0.8, 1, 0),
+        ("accuracy_21_", 0.9, 2, 1),
+    ]
+
+    for metric, (exp_key, exp_value, exp_timestamp, exp_step) in zip(metrics, expected_metrics):
+        assert metric.key == exp_key
+        assert metric.value == exp_value
+        assert metric.timestamp == exp_timestamp
+        assert metric.step == exp_step
+
+    assert mapping == expected_encoded
+
+
+def test_prepare_mlflow_params():
+    hyperparameters = {"learning_rate": "0.01", "batch_!size": "32"}
+    expected_encoded = {"learning_rate": "learning_rate", "batch__21_size": "batch_!size"}
+
+    params, mapping = prepare_mlflow_params(hyperparameters)
+
+    assert len(params) == len(hyperparameters)
+
+    for param in params:
+        assert param.key in expected_encoded
+        assert param.value == hyperparameters[mapping[param.key]]
+
+    assert mapping == expected_encoded
+
+
+def test_batch_items():
+    items = [1, 2, 3, 4, 5]
+    batches = list(batch_items(items, 2))
+    assert batches == [[1, 2], [3, 4], [5]]
+
+
+@patch("os.getenv")
+@patch("requests.Session.request")
+def test_log_to_mlflow(mock_request, mock_getenv):
+    # Set up return values for os.getenv calls
+    def getenv_side_effect(arg, default=None):
+        values = {
+            "MLFLOW_TRACKING_URI": "https://test.sagemaker.aws",
+            "MLFLOW_REGISTRY_URI": "https://registry.uri",
+            "MLFLOW_EXPERIMENT_NAME": "test_experiment",
+            "MLFLOW_ALLOW_HTTP_REDIRECTS": "true",
+        }
+        return values.get(arg, default)
+
+    mock_getenv.side_effect = getenv_side_effect
+
+    # Mock the HTTP requests
+    mock_responses = {
+        "https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name": Mock(
+            spec=requests.Response
+        ),
+        "https://test.sagemaker.aws/api/2.0/mlflow/runs/create": Mock(spec=requests.Response),
+        "https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch": [
+            Mock(spec=requests.Response),
+            Mock(spec=requests.Response),
+            Mock(spec=requests.Response),
+        ],
+        "https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate": Mock(spec=requests.Response),
+    }
+
+    mock_responses[
+        "https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name"
+    ].status_code = 200
+    mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name"].text = (
+        json.dumps(
+            {
+                "experiment_id": "existing_experiment_id",
+                "name": "test_experiment",
+                "artifact_location": "some/path",
+                "lifecycle_stage": "active",
+                "tags": {},
+            }
+        )
+    )
+
+    mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/create"].status_code = 200
+    mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/create"].text = json.dumps(
+        {"run_id": "test_run_id"}
+    )
+
+    for mock_response in mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch"]:
+        mock_response.status_code = 200
+        mock_response.text = json.dumps({})
+
+    mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate"].status_code = 200
+    mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate"].text = json.dumps({})
+
+    mock_request.side_effect = [
+        mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name"],
+        mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/create"],
+        *mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch"],
+        mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate"],
+    ]
+
+    metrics = [Metric("loss", 0.1, 1, 0)]
+    params = [Param("learning_rate", "0.01")]
+    tags = {"tag1": "value1"}
+
+    log_to_mlflow(metrics, params, tags)
+
+    assert mock_request.call_count == 6  # get-by-name + create + 3 log-batch + terminate
+
+
+@patch("sagemaker.mlflow.forward_sagemaker_metrics.get_training_job_details")
+@patch("sagemaker.mlflow.forward_sagemaker_metrics.create_metric_queries")
+@patch("sagemaker.mlflow.forward_sagemaker_metrics.get_metric_data")
+@patch("sagemaker.mlflow.forward_sagemaker_metrics.prepare_mlflow_metrics")
+@patch("sagemaker.mlflow.forward_sagemaker_metrics.prepare_mlflow_params")
+@patch("sagemaker.mlflow.forward_sagemaker_metrics.log_to_mlflow")
+def test_log_sagemaker_job_to_mlflow(
+    mock_log_to_mlflow,
+    mock_prepare_params,
+    mock_prepare_metrics,
+    mock_get_metric_data,
+    mock_create_queries,
+    mock_get_job_details,
+):
+    mock_get_job_details.return_value = {
+        "HyperParameters": {"learning_rate": "0.01"},
+        "AlgorithmSpecification": {"MetricDefinitions": [{"Name": "loss"}]},
+        "TrainingJobArn": "arn:aws:sagemaker:us-west-2:123456789012:training-job/test-job",
+    }
+    mock_create_queries.return_value = [{"MetricName": "loss"}]
+    mock_get_metric_data.return_value = {"MetricQueryResults": []}
+    mock_prepare_metrics.return_value = ([], {})
+    mock_prepare_params.return_value = ([], {})
+
+    log_sagemaker_job_to_mlflow("test-job")
+
+    mock_get_job_details.assert_called_once()
+    mock_create_queries.assert_called_once()
+    mock_get_metric_data.assert_called_once()
+    mock_prepare_metrics.assert_called_once()
+    mock_prepare_params.assert_called_once()
+    mock_log_to_mlflow.assert_called_once()
+
+
+if __name__ == "__main__":
+    pytest.main()
diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py
index 73006ae7cd..628d124716 100644
--- a/tests/unit/test_estimator.py
+++ b/tests/unit/test_estimator.py
@@ -5922,3 +5922,38 @@ def test_estimator_get_app_url_fail(sagemaker_session):
         f.get_app_url("fake-app")
 
     assert "does not support URL retrieval." in str(error)
+
+
+@patch("sagemaker.estimator.log_sagemaker_job_to_mlflow")
+def test_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session):
+    f = DummyFramework(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        enable_network_isolation=True,
+        sagemaker_session=sagemaker_session,
+        instance_groups=[
+            InstanceGroup("group1", "ml.c4.xlarge", 1),
+        ],
+    )
+
+    # patch.dict restores the environment variables to their original state after the test.
+    with patch.dict(os.environ, {"MLFLOW_TRACKING_URI": "test_uri"}):
+        f.fit("s3://mydata")
+
+    mock_log_to_mlflow.assert_called_once()
+
+
+@patch("sagemaker.estimator.log_sagemaker_job_to_mlflow")
+def test_no_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session):
+    f = DummyFramework(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        enable_network_isolation=False,
+        instance_groups=[
+            InstanceGroup("group1", "ml.c4.xlarge", 1),
+        ],
+    )
+    with patch.dict(os.environ, {"MLFLOW_TRACKING_URI": "test_uri"}):
+        f.fit("s3://mydata")
+    mock_log_to_mlflow.assert_not_called()
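
---

Usage note (not part of the patch): a minimal sketch of how this change is exercised end to end, assuming an MLflow tracking server reachable from the client. The tracking URI, experiment name, image URI, role ARN, and S3 path below are placeholder values, not real resources.

    import os

    from sagemaker.estimator import Estimator

    # Placeholder values; substitute your own resources.
    os.environ["MLFLOW_TRACKING_URI"] = "https://<your-tracking-server>"
    os.environ["MLFLOW_EXPERIMENT_NAME"] = "my-experiment"

    estimator = Estimator(
        image_uri="<training-image-uri>",
        role="<execution-role-arn>",
        instance_count=1,
        instance_type="ml.m5.xlarge",
        # Forwarding only triggers for network-isolated jobs, which cannot
        # reach the tracking server themselves; fit() then forces wait=True
        # and calls log_sagemaker_job_to_mlflow() once the job completes.
        enable_network_isolation=True,
    )

    estimator.fit("s3://<bucket>/training-data")

Metrics for an already completed job can also be forwarded directly, per the signature added in this patch (the ARN is a placeholder):

    from sagemaker.mlflow.forward_sagemaker_metrics import log_sagemaker_job_to_mlflow

    log_sagemaker_job_to_mlflow(
        "arn:aws:sagemaker:us-west-2:111122223333:training-job/<job-name>"
    )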