Skip to content
This repository was archived by the owner on Feb 18, 2026. It is now read-only.

Commit 3cddc7c

Browse files
mckornfieldtylersbray
authored andcommitted
Add azure and google storage deps + wrap smart_open calls for relational
* Add azure and google storage deps * Update relational SDK to use transport params * Wrap smart_open calls to insert the transport_params as necessary for Azure GitOrigin-RevId: e13e70c9308c6a27099c998f82f4e3267463ef95
1 parent 35684de commit 3cddc7c

File tree

7 files changed

+29
-25
lines changed

7 files changed

+29
-25
lines changed

src/gretel_trainer/relational/core.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import networkx
2727
import pandas as pd
28-
import smart_open
2928

3029
from networkx.algorithms.cycles import simple_cycles
3130
from networkx.algorithms.dag import dag_longest_path_length, topological_sort
@@ -34,6 +33,7 @@
3433

3534
import gretel_trainer.relational.json as relational_json
3635

36+
from gretel_client.projects.artifact_handlers import open_artifact
3737
from gretel_trainer.relational.json import (
3838
IngestResponseT,
3939
InventedTableMetadata,
@@ -273,7 +273,7 @@ def add_table(
273273
preview_df = data.head(PREVIEW_ROW_COUNT)
274274
elif isinstance(data, (str, Path)):
275275
data_location = self.source_data_handler.resolve_data_location(data)
276-
with smart_open.open(data_location, "rb") as d:
276+
with open_artifact(data_location, "rb") as d:
277277
preview_df = pd.read_csv(d, nrows=PREVIEW_ROW_COUNT)
278278
columns = list(preview_df.columns)
279279
json_cols = relational_json.get_json_columns(preview_df)
@@ -293,7 +293,7 @@ def add_table(
293293
if isinstance(data, pd.DataFrame):
294294
df = data
295295
elif isinstance(data, (str, Path)):
296-
with smart_open.open(data, "rb") as d:
296+
with open_artifact(data, "rb") as d:
297297
df = pd.read_csv(d)
298298
rj_ingest = relational_json.ingest(name, primary_key, df, json_cols)
299299

@@ -359,7 +359,7 @@ def _add_single_table(
359359
if columns is not None:
360360
cols = columns
361361
else:
362-
with smart_open.open(source, "rb") as src:
362+
with open_artifact(source, "rb") as src:
363363
cols = list(pd.read_csv(src, nrows=1).columns)
364364
metadata = TableMetadata(
365365
primary_key=primary_key,
@@ -762,7 +762,7 @@ def get_table_data(
762762
"""
763763
source = self.get_table_source(table)
764764
usecols = usecols or self.get_table_columns(table)
765-
with smart_open.open(source, "rb") as src:
765+
with open_artifact(source, "rb") as src:
766766
return pd.read_csv(src, usecols=usecols)
767767

768768
def get_table_columns(self, table: str) -> list[str]:

src/gretel_trainer/relational/multi_table.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from gretel_client.config import get_session_config, RunnerMode
2727
from gretel_client.projects import create_project, get_project, Project
28+
from gretel_client.projects.artifact_handlers import open_artifact
2829
from gretel_client.projects.jobs import ACTIVE_STATES, END_STATES, Status
2930
from gretel_client.projects.records import RecordHandler
3031
from gretel_trainer.relational.artifacts import ArtifactCollection
@@ -651,7 +652,7 @@ def run_transforms(
651652
if isinstance(data_source, pd.DataFrame):
652653
data_source.to_csv(transforms_run_path, index=False)
653654
else:
654-
with smart_open.open(data_source, "rb") as src, smart_open.open(
655+
with open_artifact(data_source, "rb") as src, open_artifact(
655656
transforms_run_path, "wb"
656657
) as dest:
657658
shutil.copyfileobj(src, dest)
@@ -690,7 +691,10 @@ def run_transforms(
690691
for table, df in reshaped_tables.items():
691692
filename = f"transformed_{table}.csv"
692693
out_path = self._output_handler.filepath_for(filename, subdir=run_subdir)
693-
with smart_open.open(out_path, "wb") as dest:
694+
with open_artifact(
695+
out_path,
696+
"wb",
697+
) as dest:
694698
df.to_csv(
695699
dest,
696700
index=False,
@@ -899,7 +903,7 @@ def generate(
899903
synth_csv_path = self._output_handler.filepath_for(
900904
f"synth_{table}.csv", subdir=run_subdir
901905
)
902-
with smart_open.open(synth_csv_path, "wb") as dest:
906+
with open_artifact(synth_csv_path, "wb") as dest:
903907
synth_df.to_csv(
904908
dest,
905909
index=False,
@@ -1042,7 +1046,7 @@ def create_relational_report(self, run_identifier: str, filepath: str) -> None:
10421046
now=datetime.utcnow(),
10431047
run_identifier=run_identifier,
10441048
)
1045-
with smart_open.open(filepath, "w") as report:
1049+
with open_artifact(filepath, "w") as report:
10461050
html_content = ReportRenderer().render(presenter)
10471051
report.write(html_content)
10481052

@@ -1054,8 +1058,8 @@ def _attach_existing_reports(self, run_id: str, table: str) -> None:
10541058
f"synthetics_cross_table_evaluation_{table}.json", subdir=run_id
10551059
)
10561060

1057-
individual_report_json = json.loads(smart_open.open(individual_path).read())
1058-
cross_table_report_json = json.loads(smart_open.open(cross_table_path).read())
1061+
individual_report_json = json.loads(open_artifact(individual_path).read())
1062+
cross_table_report_json = json.loads(open_artifact(cross_table_path).read())
10591063

10601064
self._evaluations[table].individual_report_json = individual_report_json
10611065
self._evaluations[table].cross_table_report_json = cross_table_report_json

src/gretel_trainer/relational/sdk_extras.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
import pandas as pd
99
import requests
10-
import smart_open
1110

11+
from gretel_client.projects.artifact_handlers import open_artifact
1212
from gretel_client.projects.jobs import Job, Status
1313
from gretel_client.projects.models import Model
1414
from gretel_client.projects.projects import Project
@@ -57,9 +57,9 @@ def download_file_artifact(
5757
out_path: Union[str, Path],
5858
) -> bool:
5959
try:
60-
with gretel_object.get_artifact_handle(
61-
artifact_name
62-
) as src, smart_open.open(out_path, "wb") as dest:
60+
with gretel_object.get_artifact_handle(artifact_name) as src, open_artifact(
61+
out_path, "wb"
62+
) as dest:
6363
shutil.copyfileobj(src, dest)
6464
return True
6565
except:
@@ -73,7 +73,7 @@ def download_tar_artifact(
7373
try:
7474
response = requests.get(download_link)
7575
if response.status_code == 200:
76-
with smart_open.open(out_path, "wb") as out:
76+
with open_artifact(out_path, "wb") as out:
7777
out.write(response.content)
7878
except:
7979
logger.warning(f"Failed to download `{artifact_name}`")

src/gretel_trainer/relational/strategies/ancestral.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
from typing import Any, Union
44

55
import pandas as pd
6-
import smart_open
76

87
import gretel_trainer.relational.ancestry as ancestry
98
import gretel_trainer.relational.strategies.common as common
109

10+
from gretel_client.projects.artifact_handlers import open_artifact
1111
from gretel_trainer.relational.core import (
1212
GretelModelConfig,
1313
MultiTableException,
@@ -73,7 +73,7 @@ def prepare_training_data(
7373
tableset=altered_tableset,
7474
ancestral_seeding=True,
7575
)
76-
with smart_open.open(path, "wb") as dest:
76+
with open_artifact(path, "wb") as dest:
7777
data.to_csv(dest, index=False)
7878

7979
return table_paths
@@ -164,7 +164,7 @@ def get_generation_job(
164164
seed_path = output_handler.filepath_for(
165165
f"synthetics_seed_{table}.csv", subdir=subdir
166166
)
167-
with smart_open.open(seed_path, "wb") as dest:
167+
with open_artifact(seed_path, "wb") as dest:
168168
seed_df.to_csv(dest, index=False)
169169
return {"data_source": str(seed_path)}
170170

src/gretel_trainer/relational/strategies/independent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from typing import Any
55

66
import pandas as pd
7-
import smart_open
87

98
import gretel_trainer.relational.strategies.common as common
109

10+
from gretel_client.projects.artifact_handlers import open_artifact
1111
from gretel_trainer.relational.core import ForeignKey, GretelModelConfig, RelationalData
1212
from gretel_trainer.relational.output_handler import OutputHandler
1313

@@ -49,7 +49,7 @@ def prepare_training_data(
4949
use_columns = [col for col in all_columns if col not in columns_to_drop]
5050

5151
source_path = rel_data.get_table_source(table)
52-
with smart_open.open(source_path, "rb") as src, smart_open.open(
52+
with open_artifact(source_path, "rb") as src, open_artifact(
5353
path, "wb"
5454
) as dest:
5555
pd.DataFrame(columns=use_columns).to_csv(dest, index=False)

src/gretel_trainer/relational/tasks/classify.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import shutil
22

3-
import smart_open
4-
53
import gretel_trainer.relational.tasks.common as common
64

5+
from gretel_client.projects.artifact_handlers import open_artifact
76
from gretel_client.projects.jobs import Job
87
from gretel_client.projects.models import Model
98
from gretel_client.projects.projects import Project
@@ -150,7 +149,7 @@ def _write_results(self, job: Job, table: str) -> None:
150149

151150
destpath = self.output_handler.filepath_for(filename)
152151

153-
with job.get_artifact_handle(artifact_name) as src, smart_open.open(
152+
with job.get_artifact_handle(artifact_name) as src, open_artifact(
154153
str(destpath), "wb"
155154
) as dest:
156155
shutil.copyfileobj(src, dest)

src/gretel_trainer/relational/tasks/synthetics_evaluate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import gretel_trainer.relational.tasks.common as common
1010

11+
from gretel_client.projects.artifact_handlers import open_artifact
1112
from gretel_client.projects.jobs import Job
1213
from gretel_client.projects.models import Model
1314
from gretel_client.projects.projects import Project
@@ -144,7 +145,7 @@ def _read_json_report(model: Model, json_report_filepath: str) -> Optional[dict]
144145
also fails, log a warning and give up gracefully.
145146
"""
146147
try:
147-
return json.loads(smart_open.open(json_report_filepath).read())
148+
return json.loads(open_artifact(json_report_filepath).read())
148149
except:
149150
try:
150151
with model.get_artifact_handle("report_json") as report:

0 commit comments

Comments
 (0)