diff --git a/src/sagemaker/serve/utils/telemetry_logger.py b/src/sagemaker/serve/utils/telemetry_logger.py index a1a0408718..c02fe9bf78 100644 --- a/src/sagemaker/serve/utils/telemetry_logger.py +++ b/src/sagemaker/serve/utils/telemetry_logger.py @@ -19,7 +19,7 @@ from sagemaker import Session, exceptions from sagemaker.serve.mode.function_pointers import Mode -from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH +from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH, MLFLOW_TRACKING_ARN from sagemaker.serve.utils.exceptions import ModelBuilderException from sagemaker.serve.utils.lineage_constants import ( MLFLOW_LOCAL_PATH, @@ -144,6 +144,9 @@ def wrapper(self, *args, **kwargs): mlflow_model_path = self.model_metadata[MLFLOW_MODEL_PATH] mlflow_model_path_type = _get_mlflow_model_path_type(mlflow_model_path) extra += f"&x-mlflowModelPathType={MLFLOW_MODEL_PATH_CODE[mlflow_model_path_type]}" + mlflow_model_tracking_server_arn = self.model_metadata.get(MLFLOW_TRACKING_ARN) + if mlflow_model_tracking_server_arn is not None: + extra += f"&x-mlflowTrackingServerArn={mlflow_model_tracking_server_arn}" if getattr(self, "model_hub", False): extra += f"&x-modelHub={MODEL_HUB_TO_CODE[str(self.model_hub)]}" diff --git a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py index 4729efbda4..fc832ad02d 100644 --- a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py +++ b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py @@ -14,7 +14,7 @@ import unittest from unittest.mock import Mock, patch, MagicMock from sagemaker.serve import Mode, ModelServer -from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH +from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH, MLFLOW_TRACKING_ARN from sagemaker.serve.utils.telemetry_logger import ( _send_telemetry, _capture_telemetry, @@ -40,7 +40,10 @@ MOCK_HUGGINGFACE_ID = "meta-llama/Llama-2-7b-hf" MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex") MOCK_ENDPOINT_ARN = "arn:aws:sagemaker:us-west-2:123456789012:endpoint/test" -MOCK_MODEL_METADATA_FOR_MLFLOW = {MLFLOW_MODEL_PATH: "s3://some_path"} +MOCK_MODEL_METADATA_FOR_MLFLOW = { + MLFLOW_MODEL_PATH: "s3://some_path", + MLFLOW_TRACKING_ARN: "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", +} class ModelBuilderMock: @@ -274,6 +277,7 @@ def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry): f"&x-defaultImageUsage={ImageUriOption.DEFAULT_IMAGE.value}" f"&x-endpointArn={MOCK_ENDPOINT_ARN}" f"&x-mlflowModelPathType=2" + f"&x-mlflowTrackingServerArn={MOCK_MODEL_METADATA_FOR_MLFLOW[MLFLOW_TRACKING_ARN]}" f"&x-latency={latency}" )