1212# language governing permissions and limitations under the License.
1313"""Holds the ModelBuilder class and the ModelServer enum."""
1414from __future__ import absolute_import
15+
16+ import importlib .util
1517import uuid
1618from typing import Any , Type , List , Dict , Optional , Union
1719from dataclasses import dataclass , field
1820import logging
1921import os
22+ import re
2023
2124from pathlib import Path
2225
4346from sagemaker .predictor import Predictor
4447from sagemaker .serve .model_format .mlflow .constants import (
4548 MLFLOW_MODEL_PATH ,
49+ MLFLOW_TRACKING_ARN ,
50+ MLFLOW_RUN_ID_REGEX ,
51+ MLFLOW_REGISTRY_PATH_REGEX ,
52+ MODEL_PACKAGE_ARN_REGEX ,
4653 MLFLOW_METADATA_FILE ,
4754 MLFLOW_PIP_DEPENDENCY_FILE ,
4855)
4956from sagemaker .serve .model_format .mlflow .utils import (
5057 _get_default_model_server_for_mlflow ,
51- _mlflow_input_is_local_path ,
5258 _download_s3_artifacts ,
5359 _select_container_for_mlflow_model ,
5460 _generate_mlflow_artifact_path ,
@@ -276,8 +282,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
276282 default = None ,
277283 metadata = {
278284 "help" : "Define the model metadata to override, currently supports `HF_TASK`, "
279- "`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in "
280- "the Hub, Adding unsupported task types will throw an exception"
285+ "`MLFLOW_MODEL_PATH`, and `MLFLOW_TRACKING_ARN`. HF_TASK should be set for new "
286+ "models without task metadata in the Hub, Adding unsupported task types will "
287+ "throw an exception"
281288 },
282289 )
283290
@@ -502,6 +509,7 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
502509 mlflow_model_path = self .model_metadata [MLFLOW_MODEL_PATH ],
503510 s3_upload_path = self .s3_upload_path ,
504511 sagemaker_session = self .sagemaker_session ,
512+ tracking_server_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN ),
505513 )
506514 return new_model_package
507515
@@ -572,6 +580,7 @@ def _model_builder_deploy_wrapper(
572580 mlflow_model_path = self .model_metadata [MLFLOW_MODEL_PATH ],
573581 s3_upload_path = self .s3_upload_path ,
574582 sagemaker_session = self .sagemaker_session ,
583+ tracking_server_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN ),
575584 )
576585 return predictor
577586
@@ -625,11 +634,30 @@ def wrapper(*args, **kwargs):
625634
626635 return wrapper
627636
628- def _check_if_input_is_mlflow_model (self ) -> bool :
629- """Checks whether an MLmodel file exists in the given directory.
637+ def _handle_mlflow_input (self ):
638+ """Check whether an MLflow model is present and handle accordingly"""
639+ self ._is_mlflow_model = self ._has_mlflow_arguments ()
640+ if not self ._is_mlflow_model :
641+ return
642+
643+ mlflow_model_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
644+ artifact_path = self ._get_artifact_path (mlflow_model_path )
645+ if not self ._mlflow_metadata_exists (artifact_path ):
646+ logger .info (
647+ "MLflow model metadata not detected in %s. ModelBuilder is not "
648+ "handling MLflow model input" ,
649+ mlflow_model_path ,
650+ )
651+ return
652+
653+ self ._initialize_for_mlflow (artifact_path )
654+ _validate_input_for_mlflow (self .model_server , self .env_vars .get ("MLFLOW_MODEL_FLAVOR" ))
655+
656+ def _has_mlflow_arguments (self ) -> bool :
657+ """Check whether MLflow model arguments are present
630658
631659 Returns:
632- bool: True if the MLmodel file exists , False otherwise.
660+ bool: True if MLflow arguments are present , False otherwise.
633661 """
634662 if self .inference_spec or self .model :
635663 logger .info (
@@ -644,16 +672,82 @@ def _check_if_input_is_mlflow_model(self) -> bool:
644672 )
645673 return False
646674
647- path = self .model_metadata .get (MLFLOW_MODEL_PATH )
648- if not path :
675+ mlflow_model_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
676+ if not mlflow_model_path :
649677 logger .info (
650678 "%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model "
651679 "input" ,
652680 MLFLOW_MODEL_PATH ,
653681 )
654682 return False
655683
656- # Check for S3 path
684+ return True
685+
686+ def _get_artifact_path (self , mlflow_model_path : str ) -> str :
687+ """Retrieves the model artifact location given the Mlflow model input.
688+
689+ Args:
690+ mlflow_model_path (str): The MLflow model path input.
691+
692+ Returns:
693+ str: The path to the model artifact.
694+ """
695+ if (is_run_id_type := re .match (MLFLOW_RUN_ID_REGEX , mlflow_model_path )) or re .match (
696+ MLFLOW_REGISTRY_PATH_REGEX , mlflow_model_path
697+ ):
698+ mlflow_tracking_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN )
699+ if not mlflow_tracking_arn :
700+ raise ValueError (
701+ "%s is not provided in ModelMetadata or through set_tracking_arn "
702+ "but MLflow model path was provided." % MLFLOW_TRACKING_ARN ,
703+ )
704+
705+ if not importlib .util .find_spec ("sagemaker_mlflow" ):
706+ raise ImportError (
707+ "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
708+ )
709+
710+ import mlflow
711+
712+ mlflow .set_tracking_uri (mlflow_tracking_arn )
713+ if is_run_id_type :
714+ _ , run_id , model_path = mlflow_model_path .split ("/" , 2 )
715+ artifact_uri = mlflow .get_run (run_id ).info .artifact_uri
716+ if not artifact_uri .endswith ("/" ):
717+ artifact_uri += "/"
718+ return artifact_uri + model_path
719+
720+ mlflow_client = mlflow .MlflowClient ()
721+ if not mlflow_model_path .endswith ("/" ):
722+ mlflow_model_path += "/"
723+
724+ if "@" in mlflow_model_path :
725+ _ , model_name_and_alias , artifact_uri = mlflow_model_path .split ("/" , 2 )
726+ model_name , model_alias = model_name_and_alias .split ("@" )
727+ model_metadata = mlflow_client .get_model_version_by_alias (model_name , model_alias )
728+ else :
729+ _ , model_name , model_version , artifact_uri = mlflow_model_path .split ("/" , 3 )
730+ model_metadata = mlflow_client .get_model_version (model_name , model_version )
731+
732+ source = model_metadata .source
733+ if not source .endswith ("/" ):
734+ source += "/"
735+ return source + artifact_uri
736+
737+ if re .match (MODEL_PACKAGE_ARN_REGEX , mlflow_model_path ):
738+ model_package = self .sagemaker_session .sagemaker_client .describe_model_package (
739+ ModelPackageName = mlflow_model_path
740+ )
741+ return model_package ["SourceUri" ]
742+
743+ return mlflow_model_path
744+
745+ def _mlflow_metadata_exists (self , path : str ) -> bool :
746+ """Checks whether an MLmodel file exists in the given directory.
747+
748+ Returns:
749+ bool: True if the MLmodel file exists, False otherwise.
750+ """
657751 if path .startswith ("s3://" ):
658752 s3_downloader = S3Downloader ()
659753 if not path .endswith ("/" ):
@@ -665,17 +759,18 @@ def _check_if_input_is_mlflow_model(self) -> bool:
665759 file_path = os .path .join (path , MLFLOW_METADATA_FILE )
666760 return os .path .isfile (file_path )
667761
668- def _initialize_for_mlflow (self ) -> None :
669- """Initialize mlflow model artifacts, image uri and model server."""
670- mlflow_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
671- if not _mlflow_input_is_local_path (mlflow_path ):
672- # TODO: extend to package arn, run id and etc.
673- logger .info (
674- "Start downloading model artifacts from %s to %s" , mlflow_path , self .model_path
675- )
676- _download_s3_artifacts (mlflow_path , self .model_path , self .sagemaker_session )
762+ def _initialize_for_mlflow (self , artifact_path : str ) -> None :
763+ """Initialize mlflow model artifacts, image uri and model server.
764+
765+ Args:
766+ artifact_path (str): The path to the artifact store.
767+ """
768+ if artifact_path .startswith ("s3://" ):
769+ _download_s3_artifacts (artifact_path , self .model_path , self .sagemaker_session )
770+ elif os .path .exists (artifact_path ):
771+ _copy_directory_contents (artifact_path , self .model_path )
677772 else :
678- _copy_directory_contents ( mlflow_path , self . model_path )
773+ raise ValueError ( "Invalid path: %s" % artifact_path )
679774 mlflow_model_metadata_path = _generate_mlflow_artifact_path (
680775 self .model_path , MLFLOW_METADATA_FILE
681776 )
@@ -728,6 +823,8 @@ def build( # pylint: disable=R0911
728823 self .role_arn = role_arn
729824 self .sagemaker_session = sagemaker_session or Session ()
730825
826+ self .sagemaker_session .settings ._local_download_dir = self .model_path
827+
731828 # https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258
732829 # decorate to_string() due to
733830 # https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015
@@ -739,14 +836,8 @@ def build( # pylint: disable=R0911
739836 self .serve_settings = self ._get_serve_setting ()
740837
741838 self ._is_custom_image_uri = self .image_uri is not None
742- self ._is_mlflow_model = self ._check_if_input_is_mlflow_model ()
743- if self ._is_mlflow_model :
744- logger .warning (
745- "Support of MLflow format models is experimental and is not intended"
746- " for production at this moment."
747- )
748- self ._initialize_for_mlflow ()
749- _validate_input_for_mlflow (self .model_server , self .env_vars .get ("MLFLOW_MODEL_FLAVOR" ))
839+
840+ self ._handle_mlflow_input ()
750841
751842 if isinstance (self .model , str ):
752843 model_task = None
@@ -836,6 +927,19 @@ def validate(self, model_dir: str) -> Type[bool]:
836927
837928 return get_metadata (model_dir )
838929
930+ def set_tracking_arn (self , arn : str ):
931+ """Set tracking server ARN"""
932+ # TODO: support native MLflow URIs
933+ if importlib .util .find_spec ("sagemaker_mlflow" ):
934+ import mlflow
935+
936+ mlflow .set_tracking_uri (arn )
937+ self .model_metadata [MLFLOW_TRACKING_ARN ] = arn
938+ else :
939+ raise ImportError (
940+ "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
941+ )
942+
839943 def _hf_schema_builder_init (self , model_task : str ):
840944 """Initialize the schema builder for the given HF_TASK
841945
0 commit comments