Skip to content

Commit 2f2f081

Browse files
yaugenst-flextylerflex
authored andcommitted
return emulated batch data instead of dict from emulated run async
1 parent 09f8e08 commit 2f2f081

File tree

3 files changed

+15
-4
lines changed

3 files changed

+15
-4
lines changed

tests/test_components/test_autograd.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,12 +220,23 @@ def emulated_run_async_fwd(simulations, **run_kwargs) -> td.SimulationData:
220220
batch_data_orig[task_name] = sim_data_orig
221221
task_ids_fwd[task_name] = task_id_fwd
222222

223-
return batch_data_orig, task_ids_fwd
223+
class EmulatedBatchData(web.BatchData):
224+
def load_sim_data(self, task_name):
225+
return batch_data_orig[task_name]
226+
227+
task_paths = {task_name: "" for task_name in simulations.keys()}
228+
229+
batch_data = EmulatedBatchData(
230+
task_paths=task_paths,
231+
task_ids=task_ids_fwd,
232+
verbose=False,
233+
)
234+
235+
return batch_data, task_ids_fwd
224236

225237
def emulated_run_async_bwd(simulations, **run_kwargs) -> td.SimulationData:
226238
vjp_dict = {}
227239
for task_name, simulation in simulations.items():
228-
task_id_fwd = task_name[:-8]
229240
vjp_dict[task_name] = emulated_run_bwd(simulation, task_name, **run_kwargs)
230241
return vjp_dict
231242

tidy3d/web/api/autograd/autograd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd
712712
sim_data_fwd = sim_data_fwd_dict[task_name]
713713
sim_fields_keys = sim_fields_keys_dict[task_name]
714714

715-
sim_data_adj = batch_data_adj.get(task_name_adj)
715+
sim_data_adj = batch_data_adj[task_name_adj]
716716

717717
sim_fields_vjp = postprocess_adj(
718718
sim_data_adj=sim_data_adj,

0 commit comments

Comments
 (0)