Skip to content

Commit 42e83fb

Browse files
FilipeFcpyaugenst-flex
authored andcommitted
feat: Persist batch with task IDs before start to allow recovery
1 parent d518f1c commit 42e83fb

File tree

3 files changed

+40
-1
lines changed

3 files changed

+40
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212

1313
### Fixed
1414
- Fixed missing amplitude factor and handling of negative normal direction case when making adjoint sources from `DiffractionMonitor`.
15+
- Improved the robustness of batch jobs. The batch state, including all `task_ids`, is now saved to `batch.hdf5` immediately after upload. This fixes an issue where an interrupted batch (e.g., due to a kernel crash or network loss) would be unrecoverable.
1516

1617
## [2.9.0] - 2025-08-04
1718

tests/test_web/test_webapi.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Tests webapi and things that depend on it
22
from __future__ import annotations
33

4+
import os
5+
46
import numpy as np
57
import pytest
68
import responses
@@ -656,6 +658,40 @@ def test_create_output_dirs(mock_webapi, tmp_path, monkeypatch):
656658
assert non_existent_dirs_batch.is_dir()
657659

658660

661+
@responses.activate
662+
def test_batch_run_saves_file_after_upload(mock_webapi, mock_job_status, tmp_path, monkeypatch):
663+
"""Test that batch.run() saves batch file with task_ids immediately after upload."""
664+
sims = {TASK_NAME: make_sim()}
665+
batch = Batch(simulations=sims, folder_name=PROJECT_NAME)
666+
667+
batch_file_saved = {"saved": False, "has_task_ids": False}
668+
original_to_file = Batch.to_file
669+
670+
def track_to_file(self, fname):
671+
batch_file_saved["saved"] = True
672+
batch_file_saved["has_task_ids"] = self.jobs is not None and TASK_NAME in self.jobs
673+
return original_to_file(self, fname)
674+
675+
# mock start to interrupt run() after upload and to_file
676+
def mock_start_interrupt(self):
677+
# at this point, upload() and to_file() should have been called
678+
assert batch_file_saved["saved"], "Batch file should be saved before start()"
679+
assert batch_file_saved["has_task_ids"], "Batch file should have task_ids"
680+
# verify file actually exists and can be loaded
681+
batch_path = self._batch_path(path_dir=str(tmp_path))
682+
assert os.path.exists(batch_path)
683+
recovered = Batch.from_file(batch_path)
684+
assert recovered.jobs[TASK_NAME].task_id == TASK_ID
685+
raise RuntimeError("Simulated interruption after upload")
686+
687+
monkeypatch.setattr(Batch, "to_file", track_to_file)
688+
monkeypatch.setattr(Batch, "start", mock_start_interrupt)
689+
690+
# run should save the batch file after upload, even if interrupted
691+
with pytest.raises(RuntimeError, match="Simulated interruption"):
692+
batch.run(path_dir=str(tmp_path))
693+
694+
659695
""" Async """
660696

661697

tidy3d/web/api/container.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,7 @@ def run(self, path_dir: str = DEFAULT_DATA_DIR) -> BatchData:
601601
"""
602602
self._check_path_dir(path_dir)
603603
self.upload()
604+
self.to_file(self._batch_path(path_dir=path_dir))
604605
self.start()
605606
self.monitor()
606607
return self.load(path_dir=path_dir)
@@ -988,7 +989,6 @@ def load(self, path_dir: str = DEFAULT_DATA_DIR, replace_existing: bool = False)
988989
allowing one to load this :class:`Batch` later using ``batch = Batch.from_file()``.
989990
"""
990991
self._check_path_dir(path_dir=path_dir)
991-
self.download(path_dir=path_dir, replace_existing=replace_existing)
992992

993993
if self.jobs is None:
994994
raise DataError("Can't load batch results, hasn't been uploaded.")
@@ -1010,6 +1010,8 @@ def load(self, path_dir: str = DEFAULT_DATA_DIR, replace_existing: bool = False)
10101010
job_data = data[task_name]
10111011
job.simulation._patch_data(data=job_data)
10121012

1013+
self.download(path_dir=path_dir, replace_existing=replace_existing)
1014+
10131015
return data
10141016

10151017
def delete(self) -> None:

0 commit comments

Comments
 (0)