@@ -1071,7 +1071,7 @@ def test_strict_types():
1071
1071
_ = JaxBox (size = (1 , 1 , [1 , 2 ]), center = (0 , 0 , 0 ))
1072
1072
1073
1073
1074
- def _test_polyslab_box (use_emulated_run ):
1074
+ def _test_polyslab_box (use_emulated_run , tmp_path ):
1075
1075
"""Make sure box made with polyslab gives equivalent gradients.
1076
1076
Note: doesn't pass now since JaxBox samples the permittivity inside and outside the box,
1077
1077
and a random permittivity data is created by the emulated run function. JaxPolySlab just
@@ -1143,7 +1143,7 @@ def f(size, center, is_box=True):
1143
1143
],
1144
1144
)
1145
1145
1146
- sim_data = run (sim , task_name = "test" )
1146
+ sim_data = run (sim , task_name = "test" , path = str ( tmp_path / RUN_FILE ) )
1147
1147
amp = extract_amp (sim_data )
1148
1148
return objective (amp )
1149
1149
@@ -1171,7 +1171,7 @@ def f(size, center, is_box=True):
1171
1171
1172
1172
1173
1173
@pytest .mark .parametrize ("sim_size_axis" , [0 , 10 ])
1174
- def test_polyslab_2d (sim_size_axis , use_emulated_run ):
1174
+ def test_polyslab_2d (sim_size_axis , use_emulated_run , tmp_path ):
1175
1175
"""Make sure box made with polyslab gives equivalent gradients (note, doesn't pass now)."""
1176
1176
1177
1177
np .random .seed (0 )
@@ -1239,7 +1239,7 @@ def f(size, center):
1239
1239
],
1240
1240
)
1241
1241
1242
- sim_data = run (sim , task_name = "test" )
1242
+ sim_data = run (sim , task_name = "test" , path = str ( tmp_path / RUN_FILE ) )
1243
1243
amp = extract_amp (sim_data )
1244
1244
return objective (amp )
1245
1245
@@ -1325,14 +1325,14 @@ def test_diff_data_angles(axis):
1325
1325
assert np .isclose (zeroth_order_theta , 0.0 )
1326
1326
1327
1327
1328
- def _test_error_regular_web ():
1328
+ def _test_error_regular_web (tmp_path ):
1329
1329
"""Test that a custom error is raised if running tidy3d through web.run()"""
1330
1330
1331
1331
sim = make_sim (permittivity = EPS , size = SIZE , vertices = VERTICES , base_eps_val = BASE_EPS_VAL )
1332
1332
import tidy3d .web as web
1333
1333
1334
1334
with pytest .raises (ValueError ):
1335
- web .run (sim , task_name = "test" )
1335
+ web .run (sim , task_name = "test" , path = str ( tmp_path / RUN_FILE ) )
1336
1336
1337
1337
1338
1338
def test_value_filter ():
@@ -1376,7 +1376,7 @@ def test_save_load_simdata(use_emulated_run, tmp_path):
1376
1376
assert sim_data == sim_data2
1377
1377
1378
1378
1379
- def _test_polyslab_scale (use_emulated_run ):
1379
+ def _test_polyslab_scale (use_emulated_run , tmp_path ):
1380
1380
"""Make sure box made with polyslab gives equivalent gradients (note, doesn't pass now)."""
1381
1381
1382
1382
nums = np .logspace (np .log10 (3 ), 3 , 13 )
@@ -1441,7 +1441,7 @@ def f(scale=1.0, vertices=vertices):
1441
1441
],
1442
1442
)
1443
1443
1444
- sim_data = run (sim , task_name = "test" )
1444
+ sim_data = run (sim , task_name = "test" , path = str ( tmp_path / RUN_FILE ) )
1445
1445
amp = extract_amp (sim_data )
1446
1446
return objective (amp )
1447
1447
@@ -1692,13 +1692,13 @@ def test_sim_data_plot_field(use_emulated_run, tmp_path):
1692
1692
"""Test splitting of regular simulation data into user and server data."""
1693
1693
1694
1694
jax_sim = make_sim (permittivity = EPS , size = SIZE , vertices = VERTICES , base_eps_val = BASE_EPS_VAL )
1695
- jax_sim_data = run (jax_sim , task_name = "test" )
1695
+ jax_sim_data = run (jax_sim , task_name = "test" , path = str ( tmp_path / RUN_FILE ) )
1696
1696
ax = jax_sim_data .plot_field ("field" , "Ez" , "real" , f = 1e14 )
1697
1697
# plt.show()
1698
1698
assert len (ax .collections ) == 1
1699
1699
1700
1700
1701
- def test_pytreedef_errors (use_emulated_run ):
1701
+ def test_pytreedef_errors (use_emulated_run , tmp_path ):
1702
1702
"""Fix errors that occur when jax doesnt know how to handle array types in aux_data."""
1703
1703
1704
1704
vertices = [(0 , 0 ), (1 , 0 ), (1 , 1 ), (0 , 1 )]
@@ -1758,7 +1758,7 @@ def f(x):
1758
1758
boundary_spec = td .BoundarySpec .pml (x = False , y = False , z = False ),
1759
1759
)
1760
1760
1761
- sd = run (sim , task_name = "test" )
1761
+ sd = run (sim , task_name = "test" , path = str ( tmp_path / RUN_FILE ) )
1762
1762
1763
1763
return jnp .sum (jnp .abs (jnp .array (sd ["test" ].amps .values )))
1764
1764
0 commit comments