This repository was archived by the owner on Feb 18, 2026. It is now read-only.

Commit 25aeaa3

Replace all trainer smart_open+get_artifact_link with get_artifact_handle

GitOrigin-RevId: d192abacedea7430175f4e9920530cc80c32feb6
1 parent e037974

File tree

6 files changed (+25, -27 lines)

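Every call site in this commit follows the same substitution: instead of resolving a signed artifact link and opening it with smart_open, the code asks the client for a managed file handle and reads inside a with block. A minimal before/after sketch with hypothetical helper names (fetch_report_old/fetch_report_new), assuming only what these diffs show about the gretel_client API (get_artifact_handle returns a context manager yielding a readable file object):

import json

import smart_open


def fetch_report_old(model) -> dict:
    # Old pattern: resolve a signed link, then open it via smart_open.
    # The returned file object is never explicitly closed.
    return json.loads(
        smart_open.open(model.get_artifact_link("report_json")).read()
    )


def fetch_report_new(model) -> dict:
    # New pattern: the client returns a context manager, so the handle
    # is closed deterministically when the with block exits.
    with model.get_artifact_handle("report_json") as report:
        return json.loads(report.read())

One practical upside visible below: most modules drop their smart_open import entirely; only classify.py keeps it, and only for writing the local destination file.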

src/gretel_trainer/benchmark/sdk_extras.py

Lines changed: 2 additions & 5 deletions

@@ -3,8 +3,6 @@
 
 from typing import Any
 
-import smart_open
-
 from gretel_client.projects.jobs import (
     ACTIVE_STATES,
     END_STATES,
@@ -41,9 +39,8 @@ def run_evaluate(
     job_status = await_job(run_identifier, evaluate_model, "evaluation", wait)
     if job_status in END_STATES and job_status != Status.COMPLETED:
         raise BenchmarkException("Evaluate failed")
-    return json.loads(
-        smart_open.open(evaluate_model.get_artifact_link("report_json")).read()
-    )
+    with evaluate_model.get_artifact_handle("report_json") as report:
+        return json.loads(report.read())
 
 
 def _make_evaluate_config(run_identifier: str) -> dict:
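For context, run_evaluate's failure handling (unchanged by this commit) treats any terminal job state other than COMPLETED as an error. A restatement of that check as a hypothetical standalone predicate, using only the imports visible at the top of this file:

from gretel_client.projects.jobs import END_STATES, Status


def evaluation_failed(job_status) -> bool:
    # Terminal but not COMPLETED (e.g. errored or cancelled) means failure.
    return job_status in END_STATES and job_status != Status.COMPLETED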

src/gretel_trainer/relational/strategies/common.py

Lines changed: 2 additions & 3 deletions

@@ -40,9 +40,8 @@ def read_report_json_data(model: Model, report_path: Path) -> Optional[dict]:
 
 def _get_report_json(model: Model) -> Optional[dict]:
     try:
-        return json.loads(
-            smart_open.open(model.get_artifact_link("report_json")).read()
-        )
+        with model.get_artifact_handle("report_json") as report:
+            return json.loads(report.read())
     except:
         logger.warning("Failed to fetch model evaluation report JSON.")
         return None

src/gretel_trainer/relational/tasks/classify.py

Lines changed: 3 additions & 3 deletions

@@ -151,8 +151,8 @@ def _write_results(self, job: Job, table: str) -> None:
 
         destpath = self.out_dir / filename
 
-        with smart_open.open(
-            job.get_artifact_link(artifact_name), "rb"
-        ) as src, smart_open.open(str(destpath), "wb") as dest:
+        with job.get_artifact_handle(artifact_name) as src, smart_open.open(
+            str(destpath), "wb"
+        ) as dest:
             shutil.copyfileobj(src, dest)
         self.result_filepaths.append(destpath)
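This is the one call site that still needs smart_open after the change, since the destination is a local path; only the remote side switches to a handle. A sketch of the same copy in isolation, as a hypothetical helper (download_artifact), assuming get_artifact_handle yields a binary-readable file object:

import shutil
from pathlib import Path

import smart_open


def download_artifact(job, artifact_name: str, destpath: Path) -> Path:
    # Stream the remote artifact into a local file chunk by chunk,
    # avoiding buffering the whole payload in memory.
    with job.get_artifact_handle(artifact_name) as src, smart_open.open(
        str(destpath), "wb"
    ) as dest:
        shutil.copyfileobj(src, dest)
    return destpath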

src/gretel_trainer/runner.py

Lines changed: 1 addition & 4 deletions

@@ -38,7 +38,6 @@
 from typing import List, Optional, Tuple, Union
 
 import pandas as pd
-import smart_open
 
 from gretel_client.projects import Project
 from gretel_client.projects.jobs import ACTIVE_STATES
@@ -213,9 +212,7 @@ def _update_job_status(self):
            report = current_model.peek_report()
 
            if report is None:
-               with smart_open.open(
-                   current_model.get_artifact_link("report_json")
-               ) as fin:
+               with current_model.get_artifact_handle("report_json") as fin:
                    report = json.loads(fin.read())
 
            sqs = report["synthetic_data_quality_score"]["score"]
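The fallback read expects the report shape that the test fixture below also encodes. A minimal sketch of that lookup, with the JSON structure taken from tests/benchmark/conftest.py (the score of 95 is just the fixture's stand-in value):

import json

report = json.loads('{"synthetic_data_quality_score": {"score": 95}}')
sqs = report["synthetic_data_quality_score"]["score"]
assert sqs == 95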

tests/benchmark/conftest.py

Lines changed: 7 additions & 2 deletions

@@ -48,12 +48,17 @@ def project():
 
 
 @pytest.fixture()
-def evaluate_report_path():
+def evaluate_report_handle():
     report = {"synthetic_data_quality_score": {"score": 95}}
     with tempfile.NamedTemporaryFile() as f:
         with open(f.name, "w") as j:
             json.dump(report, j)
-        yield f.name
+
+        ctxmgr = Mock()
+        ctxmgr.__enter__ = Mock(return_value=f)
+        ctxmgr.__exit__ = Mock(return_value=False)
+
+        yield ctxmgr
 
 
 @pytest.fixture()
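The fixture wires up the fake context manager by hand because a plain Mock does not implement __enter__/__exit__. For comparison, MagicMock pre-configures these dunder methods, so an equivalent (hypothetical) variant of the fixture body could read:

from unittest.mock import MagicMock

# MagicMock supports the context-manager protocol out of the box;
# only the value produced by __enter__ needs to be pinned to the temp file.
ctxmgr = MagicMock()
ctxmgr.__enter__.return_value = f  # f: the NamedTemporaryFile from the fixture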

tests/benchmark/test_benchmark.py

Lines changed: 10 additions & 10 deletions

@@ -81,11 +81,11 @@ class SharedDictLstm(GretelModel):
 }
 
 
-def test_run_with_gretel_dataset(working_dir, project, evaluate_report_path, iris):
+def test_run_with_gretel_dataset(working_dir, project, evaluate_report_handle, iris):
     evaluate_model = Mock(
         status=Status.COMPLETED,
     )
-    evaluate_model.get_artifact_link.return_value = evaluate_report_path
+    evaluate_model.get_artifact_handle.return_value = evaluate_report_handle
     project.create_model_obj.side_effect = [evaluate_model]
 
     session = compare(
@@ -107,11 +107,11 @@ def test_run_with_gretel_dataset(working_dir, project, evaluate_report_path, iri
     assert result["SQS"] == 95
 
 
-def test_run_with_custom_csv_dataset(working_dir, project, evaluate_report_path, df):
+def test_run_with_custom_csv_dataset(working_dir, project, evaluate_report_handle, df):
     evaluate_model = Mock(
         status=Status.COMPLETED,
     )
-    evaluate_model.get_artifact_link.return_value = evaluate_report_path
+    evaluate_model.get_artifact_handle.return_value = evaluate_report_handle
     project.create_model_obj.side_effect = [evaluate_model]
 
     with tempfile.NamedTemporaryFile() as f:
@@ -137,11 +137,11 @@ def test_run_with_custom_csv_dataset(working_dir, project, evaluate_report_path,
     assert result["SQS"] == 95
 
 
-def test_run_with_custom_psv_dataset(working_dir, project, evaluate_report_path, df):
+def test_run_with_custom_psv_dataset(working_dir, project, evaluate_report_handle, df):
     evaluate_model = Mock(
         status=Status.COMPLETED,
     )
-    evaluate_model.get_artifact_link.return_value = evaluate_report_path
+    evaluate_model.get_artifact_handle.return_value = evaluate_report_handle
     project.create_model_obj.side_effect = [evaluate_model]
 
     with tempfile.NamedTemporaryFile() as f:
@@ -168,12 +168,12 @@ def test_run_with_custom_psv_dataset(working_dir, project, evaluate_report_path,
 
 
 def test_run_with_custom_dataframe_dataset(
-    working_dir, project, evaluate_report_path, df
+    working_dir, project, evaluate_report_handle, df
 ):
     evaluate_model = Mock(
         status=Status.COMPLETED,
     )
-    evaluate_model.get_artifact_link.return_value = evaluate_report_path
+    evaluate_model.get_artifact_handle.return_value = evaluate_report_handle
     project.create_model_obj.side_effect = [evaluate_model]
 
     dataset = create_dataset(df, datatype="tabular", name="pets")
@@ -205,7 +205,7 @@ def test_run_with_custom_dataframe_dataset(
 
 @pytest.mark.parametrize("benchmark_model", [GretelLSTM, TailoredActgan])
 def test_run_happy_path_gretel_sdk(
-    benchmark_model, working_dir, iris, project, evaluate_report_path
+    benchmark_model, working_dir, iris, project, evaluate_report_handle
 ):
     record_handler = Mock(
         status=Status.COMPLETED,
@@ -221,7 +221,7 @@ def test_run_happy_path_gretel_sdk(
     evaluate_model = Mock(
         status=Status.COMPLETED,
     )
-    evaluate_model.get_artifact_link.return_value = evaluate_report_path
+    evaluate_model.get_artifact_handle.return_value = evaluate_report_handle
 
     project.create_model_obj.side_effect = [model, evaluate_model]
 