Skip to content
This repository was archived by the owner on Feb 18, 2026. It is now read-only.

Commit f175bb9

Browse files
committed
Don't assign job until after it is submitted
GitOrigin-RevId: ed7badb8d98cc7bc0fbbd8c44b59b9548cd3d2a8
1 parent c9d3982 commit f175bb9

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

src/gretel_trainer/benchmark/gretel/strategy_sdk.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ def get_generate_time(self) -> Optional[float]:
6464
def train(self) -> None:
6565
model_config = self._format_model_config()
6666
data_source = self.artifact_key or self.dataset.data_source
67-
self.model = self.project.create_model_obj(
67+
_model = self.project.create_model_obj(
6868
model_config=model_config, data_source=data_source
6969
)
7070
# Calling this in lieu of submit_cloud() is supposed to avoid
7171
# artifact upload. Doesn't work for more recent client versions!
72-
self.model.submit(runner_mode=RunnerMode.CLOUD)
72+
self.model = _model.submit(runner_mode=RunnerMode.CLOUD)
7373
job_status = self._await_job(self.model, "training")
7474
if job_status in END_STATES and job_status != Status.COMPLETED:
7575
raise BenchmarkException("Training failed")
@@ -78,10 +78,10 @@ def generate(self) -> None:
7878
if self.model is None:
7979
raise BenchmarkException("Cannot generate before training")
8080

81-
self.record_handler = self.model.create_record_handler_obj(
81+
_record_handler = self.model.create_record_handler_obj(
8282
params={"num_records": self.dataset.row_count}
8383
)
84-
self.record_handler.submit_cloud()
84+
self.record_handler = _record_handler.submit_cloud()
8585
job_status = self._await_job(self.record_handler, "generation")
8686
if job_status == Status.COMPLETED:
8787
self._download_synthetic_data(self.record_handler)

tests/benchmark/test_benchmark.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,14 @@ def test_run_happy_path_gretel_sdk(
209209
status=Status.COMPLETED,
210210
billing_details={"total_time_seconds": 15},
211211
)
212+
record_handler.submit_cloud.return_value = record_handler
212213

213214
model = Mock(
214215
status=Status.COMPLETED,
215216
billing_details={"total_time_seconds": 30},
216217
)
217218
model.create_record_handler_obj.return_value = record_handler
219+
model.submit.return_value = model
218220

219221
evaluate_model = Mock(
220222
status=Status.COMPLETED,
@@ -268,6 +270,7 @@ def test_sdk_model_failure(working_dir, iris, project):
268270
status=Status.ERROR,
269271
billing_details={"total_time_seconds": 30},
270272
)
273+
model.submit.return_value = model
271274

272275
project.create_model_obj.side_effect = [model]
273276

@@ -309,6 +312,7 @@ def test_custom_gretel_model_configs_do_not_overwrite_each_other(
309312
status=Status.ERROR,
310313
billing_details={"total_time_seconds": 30},
311314
)
315+
model.submit.return_value = model
312316
project.create_model_obj.return_value = model
313317

314318
pets = create_dataset(df, datatype="tabular", name="pets")

0 commit comments

Comments
 (0)