Skip to content

Commit 41cad2d

Browse files
author
Ashish Gupta
committed
Revert "fix: Move sagemaker-mlflow to extras (aws#4903)"
This reverts commit 292a00d.
1 parent 0275a2c commit 41cad2d

File tree

7 files changed

+8
-18
lines changed

7 files changed

+8
-18
lines changed

hatch_build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def read_feature_deps(feature):
2020

2121
optional_dependencies = {"all": []}
2222

23-
for feature in ("feature-processor", "huggingface", "local", "scipy", "sagemaker-mlflow"):
23+
for feature in ("feature-processor", "huggingface", "local", "scipy"):
2424
dependencies = read_feature_deps(feature)
2525
optional_dependencies[feature] = dependencies
2626
optional_dependencies["all"].extend(dependencies)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ dependencies = [
4848
"PyYAML~=6.0",
4949
"requests",
5050
"sagemaker-core>=1.0.0,<2.0.0",
51+
"sagemaker-mlflow",
5152
"schema",
5253
"smdebug_rulesconfig==1.0.1",
5354
"tblib>=1.7.0,<4",

requirements/extras/sagemaker-mlflow_requirements.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

requirements/extras/test_requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,3 @@ huggingface_hub>=0.23.4
4444
uvicorn>=0.30.1
4545
fastapi>=0.111.0
4646
nest-asyncio
47-
sagemaker-mlflow>=0.1.0

src/sagemaker/estimator.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
from sagemaker.workflow.parameters import ParameterString
108108
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
109109

110+
from sagemaker.mlflow.forward_sagemaker_metrics import log_sagemaker_job_to_mlflow
110111

111112
logger = logging.getLogger(__name__)
112113

@@ -1373,14 +1374,8 @@ def fit(
13731374
forward_to_mlflow_tracking_server = True
13741375
if wait:
13751376
self.latest_training_job.wait(logs=logs)
1376-
try:
1377-
if forward_to_mlflow_tracking_server:
1378-
from sagemaker.mlflow.forward_sagemaker_metrics import log_sagemaker_job_to_mlflow
1379-
1380-
log_sagemaker_job_to_mlflow(self.latest_training_job.name)
1381-
except ImportError:
1382-
if forward_to_mlflow_tracking_server:
1383-
raise ValueError("Unable to import mlflow, check if sagemaker-mlflow is installed")
1377+
if forward_to_mlflow_tracking_server:
1378+
log_sagemaker_job_to_mlflow(self.latest_training_job.name)
13841379

13851380
def _compilation_job_name(self):
13861381
"""Placeholder docstring"""

src/sagemaker/mlflow/forward_sagemaker_metrics.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,7 @@
2020
import re
2121
from typing import Set, Tuple, List, Dict, Generator
2222
import boto3
23-
24-
try:
25-
import mlflow
26-
except ImportError:
27-
raise ValueError("Unable to import mlflow, check if sagemaker-mlflow is installed.")
23+
import mlflow
2824
from mlflow import MlflowClient
2925
from mlflow.entities import Metric, Param, RunTag
3026

tests/unit/test_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5924,7 +5924,7 @@ def test_estimator_get_app_url_fail(sagemaker_session):
59245924
assert "does not support URL retrieval." in str(error)
59255925

59265926

5927-
@patch("sagemaker.mlflow.forward_sagemaker_metrics.log_sagemaker_job_to_mlflow")
5927+
@patch("sagemaker.estimator.log_sagemaker_job_to_mlflow")
59285928
def test_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session):
59295929
f = DummyFramework(
59305930
entry_point=SCRIPT_PATH,
@@ -5943,7 +5943,7 @@ def test_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session):
59435943
mock_log_to_mlflow.assert_called_once()
59445944

59455945

5946-
@patch("sagemaker.mlflow.forward_sagemaker_metrics.log_sagemaker_job_to_mlflow")
5946+
@patch("sagemaker.estimator.log_sagemaker_job_to_mlflow")
59475947
def test_no_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session):
59485948
f = DummyFramework(
59495949
entry_point=SCRIPT_PATH,

0 commit comments

Comments
 (0)