Skip to content

Commit ccc0f42

Browse files
marcorudolphflexyaugenst-flex
authored andcommitted
fix(tidy3d): FXC-4742-avoid-any-non-temporary-hdf-5-writing-in-tests
1 parent 6cdf059 commit ccc0f42

File tree

3 files changed

+44
-30
lines changed

3 files changed

+44
-30
lines changed

tests/test_plugins/test_mode_solver.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def test_mode_solver_fields():
384384

385385
@pytest.mark.parametrize("local", [True, False])
386386
@responses.activate
387-
def test_mode_solver_simple(mock_remote_api, local):
387+
def test_mode_solver_simple(mock_remote_api, local, tmp_path):
388388
"""Simple mode solver run (with symmetry)"""
389389

390390
simulation = td.Simulation(
@@ -423,7 +423,7 @@ def test_mode_solver_simple(mock_remote_api, local):
423423
check_ms_reduction(ms)
424424

425425
else:
426-
_ = msweb.run(ms)
426+
_ = msweb.run(ms, results_file=tmp_path / "tmp.hdf5")
427427

428428
# Testing issue 807 functions
429429
freq0 = td.C_0 / 1.55
@@ -888,7 +888,7 @@ def test_group_index(mock_remote_api, local, tmp_path):
888888
mode_spec=mode_spec.copy(update={"group_index_step": True}),
889889
freqs=freqs,
890890
)
891-
modes = ms.solve() if local else msweb.run(ms)
891+
modes = ms.solve() if local else msweb.run(ms, results_file=tmp_path / "tmp.hdf5")
892892
if local:
893893
assert (modes.n_group.sel(mode_index=0).values > 3.9).all()
894894
assert (modes.n_group.sel(mode_index=0).values < 4.2).all()
@@ -1030,7 +1030,7 @@ def test_mode_solver_method_defaults():
10301030

10311031

10321032
@responses.activate
1033-
def test_mode_solver_web_run_batch(mock_remote_api):
1033+
def test_mode_solver_web_run_batch(mock_remote_api, tmp_path):
10341034
"""Testing run_batch function for the web mode solver."""
10351035

10361036
wav = 1.5
@@ -1065,7 +1065,13 @@ def test_mode_solver_web_run_batch(mock_remote_api):
10651065
)
10661066

10671067
# Run mode solver one at a time
1068-
results = msweb.run_batch(mode_solver_list, verbose=False, folder_name="Mode Solver")
1068+
results_files = [tmp_path / f"ms_batch_{i}.hdf5" for i in range(num_of_sims)]
1069+
results = msweb.run_batch(
1070+
mode_solver_list,
1071+
verbose=False,
1072+
folder_name="Mode Solver",
1073+
results_files=results_files,
1074+
)
10691075
print(*results, sep="\n")
10701076
assert all(isinstance(x, ModeSolverData) for x in results)
10711077
assert (results[i].n_eff.shape == (num_freqs, i + 1) for i in range(num_of_sims))
@@ -1144,7 +1150,7 @@ def test_mode_solver_plot():
11441150

11451151
@pytest.mark.parametrize("local", [True, False])
11461152
@responses.activate
1147-
def test_modes_eme_sim(mock_remote_api, local):
1153+
def test_modes_eme_sim(mock_remote_api, local, tmp_path):
11481154
lambda0 = 1
11491155
freq0 = td.C_0 / lambda0
11501156
sim_size = (1, 1, 1)
@@ -1161,8 +1167,10 @@ def test_modes_eme_sim(mock_remote_api, local):
11611167
_ = solver.data
11621168
else:
11631169
with pytest.raises(SetupError):
1164-
_ = msweb.run(solver)
1165-
_ = msweb.run(solver.to_fdtd_mode_solver())
1170+
_ = msweb.run(solver, results_file=tmp_path / "eme_solver_remote.hdf5")
1171+
_ = msweb.run(
1172+
solver.to_fdtd_mode_solver(), results_file=tmp_path / "eme_solver_fdtd_remote.hdf5"
1173+
)
11661174

11671175
_ = solver.reduced_simulation_copy
11681176

tests/test_web/test_local_cache.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def _test_load_simulation_if_cached(monkeypatch, tmp_path, basic_simulation):
285285
assert counters == {"upload": 1, "start": 1, "monitor": 1, "download": 1}
286286
assert len(cache) == 1
287287

288-
sim_data_from_cache = load_simulation_if_cached(basic_simulation)
288+
sim_data_from_cache = load_simulation_if_cached(basic_simulation, path=tmp_path / "tmp.hdf5")
289289
assert sim_data_from_cache is not None
290290
assert sim_data_from_cache.simulation == basic_simulation
291291

@@ -296,28 +296,28 @@ def _test_load_simulation_if_cached(monkeypatch, tmp_path, basic_simulation):
296296

297297
def _test_mode_solver_caching(monkeypatch, tmp_path):
298298
counters = _patch_run_pipeline(monkeypatch)
299-
299+
tmp_file = tmp_path / "tmp.hdf5"
300300
# store in cache
301301
mode_sim = make_mode_sim()
302-
mode_sim_data = web.run(mode_sim)
302+
mode_sim_data = web.run(mode_sim, path=tmp_file)
303303

304304
# test basic loading from cache
305-
from_cache_data = load_simulation_if_cached(mode_sim)
305+
from_cache_data = load_simulation_if_cached(mode_sim, path=tmp_file)
306306
assert from_cache_data is not None
307307
assert isinstance(from_cache_data, _FakeStubData)
308308
assert mode_sim_data.simulation == from_cache_data.simulation
309309

310310
# test loading from run
311311
_reset_counters(counters)
312-
mode_sim_data_run = web.run(mode_sim)
312+
mode_sim_data_run = web.run(mode_sim, path=tmp_file)
313313
assert counters["download"] == 0
314314
assert isinstance(mode_sim_data_run, _FakeStubData)
315315
assert mode_sim_data.simulation == mode_sim_data_run.simulation
316316

317317
# test loading from job
318318
_reset_counters(counters)
319319
job = Job(simulation=mode_sim, task_name="test")
320-
job_data = job.run()
320+
job_data = job.run(path=tmp_file)
321321
assert counters["download"] == 0
322322
assert isinstance(job_data, _FakeStubData)
323323
assert mode_sim_data.simulation == job_data.simulation
@@ -334,14 +334,14 @@ def _test_mode_solver_caching(monkeypatch, tmp_path):
334334
cache = resolve_local_cache(True)
335335
# test storing via job
336336
cache.clear()
337-
Job(simulation=mode_sim, task_name="test").run()
338-
assert load_simulation_if_cached(mode_sim) is not None
337+
Job(simulation=mode_sim, task_name="test").run(path=tmp_file)
338+
assert load_simulation_if_cached(mode_sim, path=tmp_file) is not None
339339

340340
# test storing via batch
341341
cache.clear()
342342
batch_mode_data = Batch(simulations={"sim1": mode_sim}).run(path_dir=tmp_path)
343343
_ = batch_mode_data["sim1"] # access to store
344-
assert load_simulation_if_cached(mode_sim) is not None
344+
assert load_simulation_if_cached(mode_sim, path=tmp_file) is not None
345345

346346

347347
def _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path):
@@ -382,7 +382,7 @@ def _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path):
382382
assert len(cache) == 3
383383

384384

385-
def _test_verbosity(monkeypatch, basic_simulation):
385+
def _test_verbosity(monkeypatch, basic_simulation, tmp_path):
386386
_CSI_RE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]") # ANSI CSI
387387
_OSC8_RE = re.compile(r"\x1b\]8;.*?(?:\x1b\\|\x07)", re.DOTALL) # OSC-8 hyperlinks
388388

@@ -403,8 +403,8 @@ def _normalize_console_text(s: str) -> str:
403403
_reset_counters(counters)
404404
sim2 = basic_simulation.updated_copy(shutoff=1e-4)
405405
sim3 = basic_simulation.updated_copy(shutoff=1e-3)
406-
407-
run(basic_simulation, verbose=True) # seed cache
406+
tmp_file = tmp_path / "tmp.hdf5"
407+
run(basic_simulation, verbose=True, path=tmp_file) # seed cache
408408

409409
log_mod = importlib.import_module("tidy3d.log")
410410

@@ -424,22 +424,22 @@ def _normalize_console_text(s: str) -> str:
424424
buf.seek(0)
425425

426426
# test for load_simulation_if_cached
427-
sim_data = load_simulation_if_cached(basic_simulation, verbose=True)
427+
sim_data = load_simulation_if_cached(basic_simulation, verbose=True, path=tmp_file)
428428
assert sim_data is not None
429429
assert "Loading simulation from" in buf.getvalue(), (
430430
f"Expected 'Loading simulation from' in log, got '{buf.getvalue()}'"
431431
)
432432

433433
buf.truncate(0)
434434
buf.seek(0)
435-
load_simulation_if_cached(basic_simulation, verbose=False)
435+
load_simulation_if_cached(basic_simulation, verbose=False, path=tmp_file)
436436
assert sim_data is not None
437437
assert buf.getvalue().strip() == "", f"Expected empty log, got '{buf.getvalue()}'"
438438

439439
# test for batched runs
440440
buf.truncate(0)
441441
buf.seek(0)
442-
run([basic_simulation, sim3], verbose=True)
442+
run([basic_simulation, sim3], verbose=True, path=tmp_path)
443443
txt = _normalize_console_text(buf.getvalue())
444444
assert "Got 1 simulation from cache" in txt, (
445445
f"Expected 'Got 1 simulation from cache' in log, got '{buf.getvalue()}'"
@@ -448,13 +448,13 @@ def _normalize_console_text(s: str) -> str:
448448
# if some found
449449
buf.truncate(0)
450450
buf.seek(0)
451-
run([basic_simulation, sim2], verbose=False)
451+
run([basic_simulation, sim2], verbose=False, path=tmp_path)
452452
assert buf.getvalue().strip() == "", f"Expected empty log, got '{buf.getvalue()}'"
453453

454454
# if all found
455455
buf.truncate(0)
456456
buf.seek(0)
457-
run([basic_simulation, sim2], verbose=False)
457+
run([basic_simulation, sim2], verbose=False, path=tmp_path)
458458
assert buf.getvalue().strip() == "", f"Expected empty log, got '{buf.getvalue()}'"
459459

460460
finally:
@@ -467,7 +467,7 @@ def _test_job_run_cache(monkeypatch, basic_simulation, tmp_path):
467467
cache = resolve_local_cache(use_cache=True)
468468
cache.clear()
469469
job = Job(simulation=basic_simulation, task_name="test")
470-
job.run()
470+
job.run(path=tmp_path / "tmp.hdf5")
471471

472472
assert len(cache) == 1
473473

@@ -485,7 +485,7 @@ def _test_job_run_cache(monkeypatch, basic_simulation, tmp_path):
485485
assert os.path.exists(out2_path)
486486

487487

488-
def _test_autograd_cache(monkeypatch, request):
488+
def _test_autograd_cache(monkeypatch, request, tmp_path):
489489
counters = _patch_run_pipeline(monkeypatch)
490490

491491
# "Original" rule: the one autograd uses by default
@@ -526,7 +526,7 @@ def _restore_make_dict_vjp():
526526
def objective(params):
527527
sim = make_sim(params)
528528
sim.attrs["params"] = params
529-
sim_data = run_autograd(sim)
529+
sim_data = run_autograd(sim, path=tmp_path / "tmp.hdf5")
530530
value = postprocess(sim_data)
531531
return value
532532

@@ -823,9 +823,9 @@ def test_cache_sequential(
823823
_test_cache_stats_sync(monkeypatch, tmp_path_factory, basic_simulation)
824824
_test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path)
825825
_test_job_run_cache(monkeypatch, basic_simulation, tmp_path)
826-
_test_autograd_cache(monkeypatch, request)
826+
_test_autograd_cache(monkeypatch, request, tmp_path)
827827
_test_configure_cache_roundtrip(monkeypatch, tmp_path)
828828
_test_store_and_fetch_do_not_iterate(monkeypatch, tmp_path, basic_simulation)
829829
_test_mode_solver_caching(monkeypatch, tmp_path)
830-
_test_verbosity(monkeypatch, basic_simulation)
830+
_test_verbosity(monkeypatch, basic_simulation, tmp_path)
831831
_test_cache_cli_commands(monkeypatch, tmp_path_factory, basic_simulation)

tests/test_web/test_webapi.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,8 @@ def save_sim_to_path(path: str) -> None:
852852
PROJECT_NAME,
853853
"--inspect_credits",
854854
"--inspect_sim",
855+
"-o",
856+
str(tmp_path / "tmp.hdf5"),
855857
]
856858
)
857859

@@ -865,6 +867,8 @@ def save_sim_to_path(path: str) -> None:
865867
"--folder_name",
866868
PROJECT_NAME,
867869
"--inspect_credits",
870+
"-o",
871+
str(tmp_path / "tmp.hdf5"),
868872
]
869873
)
870874

@@ -877,6 +881,8 @@ def save_sim_to_path(path: str) -> None:
877881
"--folder_name",
878882
PROJECT_NAME,
879883
"--inspect_sim",
884+
"-o",
885+
str(tmp_path / "tmp.hdf5"),
880886
]
881887
)
882888

0 commit comments

Comments
 (0)