Skip to content

Commit c31578f

Browse files
yaugenst-flexmomchil-flex
authored andcommitted
fix adjoint tests
1 parent d97d133 commit c31578f

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

tests/test_plugins/test_adjoint.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,7 +1071,7 @@ def test_strict_types():
10711071
_ = JaxBox(size=(1, 1, [1, 2]), center=(0, 0, 0))
10721072

10731073

1074-
def _test_polyslab_box(use_emulated_run):
1074+
def _test_polyslab_box(use_emulated_run, tmp_path):
10751075
"""Make sure box made with polyslab gives equivalent gradients.
10761076
Note: doesn't pass now since JaxBox samples the permittivity inside and outside the box,
10771077
and a random permittivity data is created by the emulated run function. JaxPolySlab just
@@ -1143,7 +1143,7 @@ def f(size, center, is_box=True):
11431143
],
11441144
)
11451145

1146-
sim_data = run(sim, task_name="test")
1146+
sim_data = run(sim, task_name="test", path=str(tmp_path / RUN_FILE))
11471147
amp = extract_amp(sim_data)
11481148
return objective(amp)
11491149

@@ -1171,7 +1171,7 @@ def f(size, center, is_box=True):
11711171

11721172

11731173
@pytest.mark.parametrize("sim_size_axis", [0, 10])
1174-
def test_polyslab_2d(sim_size_axis, use_emulated_run):
1174+
def test_polyslab_2d(sim_size_axis, use_emulated_run, tmp_path):
11751175
"""Make sure box made with polyslab gives equivalent gradients (note, doesn't pass now)."""
11761176

11771177
np.random.seed(0)
@@ -1239,7 +1239,7 @@ def f(size, center):
12391239
],
12401240
)
12411241

1242-
sim_data = run(sim, task_name="test")
1242+
sim_data = run(sim, task_name="test", path=str(tmp_path / RUN_FILE))
12431243
amp = extract_amp(sim_data)
12441244
return objective(amp)
12451245

@@ -1325,14 +1325,14 @@ def test_diff_data_angles(axis):
13251325
assert np.isclose(zeroth_order_theta, 0.0)
13261326

13271327

1328-
def _test_error_regular_web():
1328+
def _test_error_regular_web(tmp_path):
13291329
"""Test that a custom error is raised if running tidy3d through web.run()"""
13301330

13311331
sim = make_sim(permittivity=EPS, size=SIZE, vertices=VERTICES, base_eps_val=BASE_EPS_VAL)
13321332
import tidy3d.web as web
13331333

13341334
with pytest.raises(ValueError):
1335-
web.run(sim, task_name="test")
1335+
web.run(sim, task_name="test", path=str(tmp_path / RUN_FILE))
13361336

13371337

13381338
def test_value_filter():
@@ -1376,7 +1376,7 @@ def test_save_load_simdata(use_emulated_run, tmp_path):
13761376
assert sim_data == sim_data2
13771377

13781378

1379-
def _test_polyslab_scale(use_emulated_run):
1379+
def _test_polyslab_scale(use_emulated_run, tmp_path):
13801380
"""Make sure box made with polyslab gives equivalent gradients (note, doesn't pass now)."""
13811381

13821382
nums = np.logspace(np.log10(3), 3, 13)
@@ -1441,7 +1441,7 @@ def f(scale=1.0, vertices=vertices):
14411441
],
14421442
)
14431443

1444-
sim_data = run(sim, task_name="test")
1444+
sim_data = run(sim, task_name="test", path=str(tmp_path / RUN_FILE))
14451445
amp = extract_amp(sim_data)
14461446
return objective(amp)
14471447

@@ -1692,13 +1692,13 @@ def test_sim_data_plot_field(use_emulated_run, tmp_path):
16921692
"""Test splitting of regular simulation data into user and server data."""
16931693

16941694
jax_sim = make_sim(permittivity=EPS, size=SIZE, vertices=VERTICES, base_eps_val=BASE_EPS_VAL)
1695-
jax_sim_data = run(jax_sim, task_name="test")
1695+
jax_sim_data = run(jax_sim, task_name="test", path=str(tmp_path / RUN_FILE))
16961696
ax = jax_sim_data.plot_field("field", "Ez", "real", f=1e14)
16971697
# plt.show()
16981698
assert len(ax.collections) == 1
16991699

17001700

1701-
def test_pytreedef_errors(use_emulated_run):
1701+
def test_pytreedef_errors(use_emulated_run, tmp_path):
17021702
"""Fix errors that occur when jax doesnt know how to handle array types in aux_data."""
17031703

17041704
vertices = [(0, 0), (1, 0), (1, 1), (0, 1)]
@@ -1758,7 +1758,7 @@ def f(x):
17581758
boundary_spec=td.BoundarySpec.pml(x=False, y=False, z=False),
17591759
)
17601760

1761-
sd = run(sim, task_name="test")
1761+
sd = run(sim, task_name="test", path=str(tmp_path / RUN_FILE))
17621762

17631763
return jnp.sum(jnp.abs(jnp.array(sd["test"].amps.values)))
17641764

0 commit comments

Comments
 (0)