Skip to content

Commit edc14a9

Browse files
Prototype new flow
1 parent 43dcd42 commit edc14a9

File tree

1 file changed

+131
-49
lines changed

1 file changed

+131
-49
lines changed

tidy3d/plugins/smatrix/run.py

Lines changed: 131 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -15,52 +15,96 @@
1515
DEFAULT_DATA_DIR = "."
1616

1717

18-
def create_batch(
19-
modeler: ComponentModelerType,
20-
path_dir: str = DEFAULT_DATA_DIR,
21-
parent_batch_id: str = None,
22-
group_id: str = None,
23-
file_name: str = "batch.hdf5",
24-
**kwargs,
25-
) -> Batch:
26-
"""Creates a simulation Batch from a component modeler and saves it to a file.
18+
def compose_simulation_data_index(
19+
port_task_map: dict[port, str]
20+
) -> IndexSimulation:
21+
port_data_dict = {}
22+
for port, task in port_task_map.items():
23+
# get simulationdata for each port
24+
port_data_dict[port] = sim_data_i
25+
26+
return IndexSimulationData(
27+
index=port_data_dict.keys(),
28+
data=port_data_dict.values()
29+
)
30+
31+
def compose_terminal_modeler_data(
32+
modeler: TerminalComponentModeler,
33+
port_task_map: dict[port, str]
34+
) -> TerminalComponentModelerData:
35+
"""Assembles `TerminalComponentModelerData` from simulation results.
36+
37+
This function maps the simulation data from a completed batch run back to the
38+
ports of the terminal component modeler.
2739
2840
Args:
29-
modeler: The component modeler that defines the set of simulations.
30-
path_dir: Directory where the batch file will be saved.
31-
file_name: Name for the HDF5 file where the batch is stored.
32-
**kwargs: Additional keyword arguments passed to the `Batch` constructor.
41+
modeler: The `TerminalComponentModeler` used to generate the simulations.
3342
3443
Returns:
35-
The configured `Batch` object ready for execution.
44+
A `TerminalComponentModelerData` object containing the results mapped to
45+
their respective ports.
3646
"""
37-
filepath = os.path.join(path_dir, file_name)
47+
port_simulation_data = compose_simulation_data_index(port_task_map)
48+
return TerminalComponentModelerData(modeler=modeler, data=port_simulation_data)
3849

39-
if parent_batch_id is not None:
40-
parent_task_dict = dict()
41-
for key in modeler.sim_dict.keys():
42-
parent_task_dict[key] = (parent_batch_id,)
43-
else:
44-
parent_task_dict = None
50+
def compose_component_modeler_data(
51+
modeler: ComponentModeler,
52+
port_task_map: dict[port, str]
53+
) -> ComponentModelerData:
54+
"""Assembles `ComponentModelerData` from simulation results.
4555
46-
if group_id is not None:
47-
group_id_dict = dict()
48-
for key in modeler.sim_dict.keys():
49-
group_id_dict[key] = (group_id,)
50-
else:
51-
group_id_dict = None
56+
This function maps the simulation data from a completed batch run back to the
57+
ports of the component modeler.
5258
53-
batch = Batch(
54-
simulations=modeler.sim_dict,
55-
parent_tasks=parent_task_dict,
56-
group_ids=group_id_dict,
57-
**kwargs,
58-
)
59-
batch.to_file(filepath)
60-
return batch
59+
Args:
60+
modeler: The `ComponentModeler` used to generate the simulations.
61+
batch_data: The results obtained from running the simulation `Batch`.
6162
63+
Returns:
64+
A `ComponentModelerData` object containing the results mapped to
65+
their respective ports.
66+
"""
67+
port_simulation_data = compose_simulation_data_index(port_task_map)
68+
return ComponentModelerData(modeler=modeler, data=port_simulation_data)
6269

63-
def compose_terminal_modeler_data(
70+
71+
def compose_modeler_data(
72+
modeler_file: str,
73+
port_task_map: dict[port, str]
74+
) -> ComponentModelerDataType:
75+
"""Selects the correct composer based on the modeler type and creates the data object.
76+
77+
This method acts as a dispatcher, inspecting the type of `modeler` to determine
78+
which composer function (`compose_component_modeler_data` or
79+
`compose_terminal_modeler_data`) to invoke.
80+
81+
Args:
82+
modeler: The component modeler, which can be either a `ComponentModeler` or
83+
a `TerminalComponentModeler`.
84+
batch_data: The results obtained from running the simulation `Batch`.
85+
86+
Returns:
87+
The appropriate `ComponentModelerDataType` object containing the simulation results.
88+
89+
Raises:
90+
TypeError: If the provided `modeler` is not a recognized type.
91+
"""
92+
json_str = Tidy3dBaseModel._json_string_from_hdf5(modeler_file)
93+
model_dict = json.loads(json_str)
94+
modeler = model_dict["type"]
95+
96+
if modeler_type == "ComponentModeler":
97+
modeler = ComponentModeler.from_file(task_file)
98+
modeler_data = compose_component_modeler_data(modeler=modeler, port_task_map=port_task_map)
99+
elif modeler_type == "TerminalComponentModeler":
100+
modeler = TerminalComponentModeler.from_file(task_file)
101+
modeler_data = compose_terminal_modeler_data(modeler=modeler, port_task_map=port_task_map)
102+
else:
103+
raise TypeError(f"Unsupported modeler type: {type(modeler).__name__}")
104+
return modeler_data
105+
106+
107+
def compose_terminal_modeler_data_from_batch_data(
64108
modeler: TerminalComponentModeler,
65109
batch_data: BatchData,
66110
) -> TerminalComponentModelerData:
@@ -83,9 +127,8 @@ def compose_terminal_modeler_data(
83127
return TerminalComponentModelerData(modeler=modeler, data=port_simulation_data)
84128

85129

86-
def compose_component_modeler_data(
130+
def compose_component_modeler_data_from_batch_data(
87131
modeler: ComponentModeler,
88-
simulation_data_list: Optional[list[SimulationData]] = None,
89132
batch_data: Optional[BatchData] = None,
90133
) -> ComponentModelerData:
91134
"""Assembles `ComponentModelerData` from simulation results.
@@ -101,18 +144,13 @@ def compose_component_modeler_data(
101144
A `ComponentModelerData` object containing the results mapped to
102145
their respective ports.
103146
"""
104-
if simulation_data_list:
105-
pass
106-
elif batch_data:
107-
ports = [modeler.get_task_name(port=port_i) for port_i in modeler.ports]
108-
data = [batch_data[modeler.get_task_name(port=port_i)] for port_i in modeler.ports]
109-
port_simulation_data = IndexSimulationData(ports=ports, data=data)
110-
else:
111-
pass
147+
ports = [modeler.get_task_name(port=port_i) for port_i in modeler.ports]
148+
data = [batch_data[modeler.get_task_name(port=port_i)] for port_i in modeler.ports]
149+
port_simulation_data = IndexSimulationData(ports=ports, data=data)
112150
return ComponentModelerData(modeler=modeler, data=port_simulation_data)
113151

114152

115-
def compose_modeler_data(
153+
def compose_modeler_data_from_batch_data(
116154
modeler: ComponentModelerType,
117155
batch_data: BatchData,
118156
) -> ComponentModelerDataType:
@@ -134,15 +172,59 @@ def compose_modeler_data(
134172
TypeError: If the provided `modeler` is not a recognized type.
135173
"""
136174
if isinstance(modeler, ComponentModeler):
137-
modeler_data = compose_component_modeler_data(modeler=modeler, batch_data=batch_data)
175+
modeler_data = compose_component_modeler_data_from_batch_data(modeler=modeler, batch_data=batch_data)
138176
elif isinstance(modeler, TerminalComponentModeler):
139-
modeler_data = compose_terminal_modeler_data(modeler=modeler, batch_data=batch_data)
177+
modeler_data = compose_terminal_modeler_data_from_batch_data(modeler=modeler, batch_data=batch_data)
140178
else:
141179
raise TypeError(f"Unsupported modeler type: {type(modeler).__name__}")
142180

143181
return modeler_data
144182

145183

184+
def create_batch(
185+
modeler: ComponentModelerType,
186+
path_dir: str = DEFAULT_DATA_DIR,
187+
parent_batch_id: str = None,
188+
group_id: str = None,
189+
file_name: str = "batch.hdf5",
190+
**kwargs,
191+
) -> Batch:
192+
"""Creates a simulation Batch from a component modeler and saves it to a file.
193+
194+
Args:
195+
modeler: The component modeler that defines the set of simulations.
196+
path_dir: Directory where the batch file will be saved.
197+
file_name: Name for the HDF5 file where the batch is stored.
198+
**kwargs: Additional keyword arguments passed to the `Batch` constructor.
199+
200+
Returns:
201+
The configured `Batch` object ready for execution.
202+
"""
203+
filepath = os.path.join(path_dir, file_name)
204+
205+
if parent_batch_id is not None:
206+
parent_task_dict = dict()
207+
for key in modeler.sim_dict.keys():
208+
parent_task_dict[key] = (parent_batch_id,)
209+
else:
210+
parent_task_dict = None
211+
212+
if group_id is not None:
213+
group_id_dict = dict()
214+
for key in modeler.sim_dict.keys():
215+
group_id_dict[key] = (group_id,)
216+
else:
217+
group_id_dict = None
218+
219+
batch = Batch(
220+
simulations=modeler.sim_dict,
221+
parent_tasks=parent_task_dict,
222+
group_ids=group_id_dict,
223+
**kwargs,
224+
)
225+
batch.to_file(filepath)
226+
return batch
227+
146228
def run(
147229
modeler: ComponentModelerType,
148230
path_dir: str = DEFAULT_DATA_DIR,

0 commit comments

Comments
 (0)