diff --git a/hatch_build.py b/hatch_build.py index fd428aa1d8..fc75584f17 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -20,7 +20,7 @@ def read_feature_deps(feature): optional_dependencies = {"all": []} - for feature in ("feature-processor", "huggingface", "local", "scipy"): + for feature in ("feature-processor", "huggingface", "local", "scipy", "sagemaker-mlflow"): dependencies = read_feature_deps(feature) optional_dependencies[feature] = dependencies optional_dependencies["all"].extend(dependencies) diff --git a/pyproject.toml b/pyproject.toml index 80dec58b99..1c60e786e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,6 @@ dependencies = [ "PyYAML~=6.0", "requests", "sagemaker-core>=1.0.0,<2.0.0", - "sagemaker-mlflow", "schema", "smdebug_rulesconfig==1.0.1", "tblib>=1.7.0,<4", diff --git a/requirements/extras/sagemaker-mlflow_requirements.txt b/requirements/extras/sagemaker-mlflow_requirements.txt new file mode 100644 index 0000000000..75f330b0e6 --- /dev/null +++ b/requirements/extras/sagemaker-mlflow_requirements.txt @@ -0,0 +1 @@ +sagemaker-mlflow>=0.1.0 diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index f08a26811e..a2c0fbfc65 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -44,3 +44,4 @@ huggingface_hub>=0.23.4 uvicorn>=0.30.1 fastapi>=0.111.0 nest-asyncio +sagemaker-mlflow>=0.1.0 diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 2aba19b112..2fbcf4373b 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -107,7 +107,6 @@ from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline -from sagemaker.mlflow.forward_sagemaker_metrics import log_sagemaker_job_to_mlflow logger = logging.getLogger(__name__) @@ -1374,8 +1373,14 @@ def fit( forward_to_mlflow_tracking_server = True if wait: self.latest_training_job.wait(logs=logs) - if forward_to_mlflow_tracking_server: - log_sagemaker_job_to_mlflow(self.latest_training_job.name) + try: + if forward_to_mlflow_tracking_server: + from sagemaker.mlflow.forward_sagemaker_metrics import log_sagemaker_job_to_mlflow + + log_sagemaker_job_to_mlflow(self.latest_training_job.name) + except ImportError: + if forward_to_mlflow_tracking_server: + raise ValueError("Unable to import mlflow, check if sagemaker-mlflow is installed") def _compilation_job_name(self): """Placeholder docstring""" diff --git a/src/sagemaker/mlflow/forward_sagemaker_metrics.py b/src/sagemaker/mlflow/forward_sagemaker_metrics.py index a025241d6a..48b217482c 100644 --- a/src/sagemaker/mlflow/forward_sagemaker_metrics.py +++ b/src/sagemaker/mlflow/forward_sagemaker_metrics.py @@ -20,7 +20,11 @@ import re from typing import Set, Tuple, List, Dict, Generator import boto3 -import mlflow + +try: + import mlflow +except ImportError: + raise ValueError("Unable to import mlflow, check if sagemaker-mlflow is installed.") from mlflow import MlflowClient from mlflow.entities import Metric, Param, RunTag diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 628d124716..31afaa0e7e 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -5924,7 +5924,7 @@ def test_estimator_get_app_url_fail(sagemaker_session): assert "does not support URL retrieval." in str(error) -@patch("sagemaker.estimator.log_sagemaker_job_to_mlflow") +@patch("sagemaker.mlflow.forward_sagemaker_metrics.log_sagemaker_job_to_mlflow") def test_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session): f = DummyFramework( entry_point=SCRIPT_PATH, @@ -5943,7 +5943,7 @@ def test_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session): mock_log_to_mlflow.assert_called_once() -@patch("sagemaker.estimator.log_sagemaker_job_to_mlflow") +@patch("sagemaker.mlflow.forward_sagemaker_metrics.log_sagemaker_job_to_mlflow") def test_no_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session): f = DummyFramework( entry_point=SCRIPT_PATH,