Skip to content

Commit bf8b6dd

Browse files
committed
introduce design plugin
1 parent 1066a3c commit bf8b6dd

File tree

8 files changed

+1543
-0
lines changed

8 files changed

+1543
-0
lines changed

CHANGELOG.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
### Added
99
- `ModeData.dispersion` and `ModeSolverData.dispersion` are calculated together with the group index.
10+
- `tidy3d.plugins.design` tool to explore user-defined design spaces.
11+
12+
### Changed
13+
14+
### Fixed
15+
1016

1117
## [2.5.0] - 2023-12-13
1218

@@ -74,6 +80,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
7480
- Ensure that mode solver fields are returned in single precision if `ModeSolver.ModeSpec.precision == "single"`.
7581
- If there are no adjoint sources for a simulation involved in an objective function, make a mock source with zero amplitude and warn user.
7682

83+
## [2.5.0rc1] - 2023-10-10
84+
85+
### Added
86+
- Time zone in webAPI logging output.
87+
- Class `Scene` consisting of a background medium and structures for easier drafting and visualization of simulation setups as well as transferring such information between different simulations.
88+
- Solver for thermal simulation (see `HeatSimulation` and related classes).
89+
- Specification of material thermal properties in medium classes through an optional field `.heat_spec`.
90+
91+
### Changed
92+
- Internal refactor of Web API functionality.
93+
- `Geometry.from_gds` doesn't create unnecessary groups of single elements.
94+
7795
## [2.4.3] - 2023-10-16
7896

7997
### Added

tests/test_plugins/test_design.py

Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
"""Test the parameter sweep plugin."""
2+
import pytest
3+
import numpy as np
4+
import tidy3d as td
5+
import matplotlib.pyplot as plt
6+
import scipy.stats.qmc as qmc
7+
8+
import tidy3d.web as web
9+
from tidy3d.components.base import Tidy3dBaseModel
10+
11+
from tidy3d.plugins import design as tdd
12+
from tidy3d.plugins.design.method import MethodIndependent
13+
14+
from ..utils import run_emulated, log_capture, assert_log_level
15+
16+
17+
SWEEP_METHODS = dict(
18+
grid=tdd.MethodGrid(),
19+
monte_carlo=tdd.MethodMonteCarlo(num_points=5),
20+
custom=tdd.MethodRandomCustom(num_points=5, sampler=qmc.Halton(d=3)),
21+
random=tdd.MethodRandom(num_points=5), # TODO: remove this if not used
22+
)
23+
24+
25+
@pytest.mark.parametrize("sweep_method", SWEEP_METHODS.values(), ids=SWEEP_METHODS.keys())
26+
def test_sweep(sweep_method, monkeypatch, ids=[]):
27+
# Problem, simulate scattering cross section of sphere ensemble
28+
# simulation consists of `num_spheres` spheres of radius `radius`.
29+
# use defines `scs` function to set up and run simulation as function of inputs.
30+
# then postprocesses the data to give the SCS.
31+
32+
monkeypatch.setattr(web, "run", run_emulated)
33+
34+
def emulated_batch_run(self, simulations, path_dir: str = None, **kwargs):
35+
data_dict = {task_name: run_emulated(sim) for task_name, sim in simulations.items()}
36+
task_ids = dict(zip(simulations.keys(), data_dict.keys()))
37+
38+
class BatchDataEmulated(Tidy3dBaseModel):
39+
"""Emulated BatchData object that just returns stored emulated data."""
40+
41+
data_dict: dict
42+
task_ids: dict
43+
44+
def items(self):
45+
for task_name, sim_data in self.data_dict.items():
46+
yield task_name, sim_data
47+
48+
def __getitem__(self, task_name):
49+
return self.data_dict[task_name]
50+
51+
return BatchDataEmulated(data_dict=data_dict, task_ids=task_ids)
52+
53+
monkeypatch.setattr(MethodIndependent, "_run_batch", emulated_batch_run)
54+
55+
# STEP1: define your design function (inputs and outputs)
56+
57+
def scs_pre(radius: float, num_spheres: int, tag: str) -> td.Simulation:
58+
"""Preprocessing function (make simulation)"""
59+
60+
# set up simulation
61+
spheres = []
62+
63+
for i in range(int(num_spheres)):
64+
spheres.append(
65+
td.Structure(
66+
geometry=td.Sphere(radius=radius),
67+
medium=td.PEC,
68+
)
69+
)
70+
71+
mnt = td.FieldMonitor(
72+
size=(0, 0, 0),
73+
center=(0, 0, 0),
74+
freqs=[2e14],
75+
name="field",
76+
)
77+
78+
return td.Simulation(
79+
size=(1, 1, 1),
80+
structures=spheres,
81+
grid_spec=td.GridSpec.auto(wavelength=1.0),
82+
run_time=1e-12,
83+
monitors=[mnt],
84+
)
85+
86+
def scs_post(sim_data: td.SimulationData) -> float:
87+
"""Postprocessing function (analyze simulation data)"""
88+
89+
mnt_data = sim_data["field"]
90+
ex_values = mnt_data.Ex.values
91+
92+
# generate a random number to add some variance to data
93+
np.random.seed(hash(sim_data) % 1000)
94+
95+
return np.sum(np.square(np.abs(ex_values))) + np.random.random()
96+
97+
def scs_pre_multi(*args, **kwargs):
98+
sim = scs_pre(*args, **kwargs)
99+
100+
return [sim, sim, sim]
101+
102+
def scs_post_multi(*sim_datas):
103+
vals = [scs_post(sim_data) for sim_data in sim_datas]
104+
return np.mean(vals)
105+
106+
def scs_pre_dict(*args, **kwargs):
107+
sims = scs_pre_multi(*args, **kwargs)
108+
keys = "abc"
109+
return dict(zip(keys, sims))
110+
111+
def scs_post_dict(a=None, b=None, c=None):
112+
sims = [a, b, c]
113+
return scs_post_multi(*sims)
114+
115+
def scs(radius: float, num_spheres: int, tag: str) -> float:
116+
"""End to end function."""
117+
118+
sim = scs_pre(radius=radius, num_spheres=num_spheres, tag=tag)
119+
120+
# run simulation
121+
sim_data = run_emulated(sim, task_name=f"SWEEP_{tag}")
122+
123+
# postprocess
124+
return scs_post(sim_data=sim_data)
125+
126+
# STEP2: define your design problem
127+
128+
radius_variable = tdd.ParameterFloat(
129+
name="radius",
130+
span=(0, 1.5),
131+
num_points=5, # note: only used for MethodGrid
132+
)
133+
134+
num_spheres_variable = tdd.ParameterInt(
135+
name="num_spheres",
136+
span=(0, 3),
137+
)
138+
139+
tag_variable = tdd.ParameterAny(name="tag", allowed_values=("tag1", "tag2", "tag3"))
140+
141+
design_space = tdd.DesignSpace(
142+
parameters=[radius_variable, num_spheres_variable, tag_variable],
143+
method=sweep_method,
144+
name="sphere CS",
145+
)
146+
147+
# STEP3: Run your design problem
148+
149+
# either supply generic function and run one by one
150+
sweep_results = design_space.run(scs)
151+
152+
# or supply function factored into pre and post and run in batch
153+
sweep_results2 = design_space.run_batch(scs_pre, scs_post)
154+
155+
sweep_results3 = design_space.run_batch(scs_pre_multi, scs_post_multi)
156+
157+
sweep_results4 = design_space.run_batch(scs_pre_dict, scs_post_dict)
158+
159+
sel_kwargs_0 = dict(zip(sweep_results.dims, sweep_results.coords[0]))
160+
sweep_results.sel(**sel_kwargs_0)
161+
162+
print(sweep_results.to_dataframe().head(10))
163+
164+
im = sweep_results.to_dataframe().plot.hexbin(x="num_spheres", y="radius", C="output")
165+
im = sweep_results.to_dataframe().plot.scatter(x="num_spheres", y="radius", c="output")
166+
plt.close()
167+
168+
design_space2 = tdd.DesignSpace(
169+
parameters=[radius_variable, num_spheres_variable, tag_variable],
170+
method=tdd.MethodMonteCarlo(num_points=3),
171+
name="sphere CS",
172+
)
173+
174+
sweep_results_other = design_space2.run(scs)
175+
176+
results_combined = sweep_results.combine(sweep_results_other)
177+
results_combined = sweep_results + sweep_results_other
178+
179+
# STEP4: modify the sweep results
180+
181+
sweep_results = sweep_results.add(
182+
fn_args={"radius": 1.2, "num_spheres": 5, "tag": "tag2"}, value=1.9
183+
)
184+
185+
sweep_results = sweep_results.delete(fn_args={"num_spheres": 5, "tag": "tag2", "radius": 1.2})
186+
187+
sweep_results_df = sweep_results.to_dataframe()
188+
189+
sweep_results_2 = tdd.Result.from_dataframe(sweep_results_df)
190+
sweep_results_3 = tdd.Result.from_dataframe(sweep_results_df, dims=sweep_results.dims)
191+
192+
assert sweep_results == sweep_results_2 == sweep_results_3
193+
194+
# VALIDATE PROPER DATAFRAME HEADERS AND DATA STORAGE
195+
196+
# make sure returning a float uses the proper output column header
197+
float_label = tdd.Result.default_value_keys(1.0)[0]
198+
assert float_label in sweep_results_df, "didn't assign column header properly for float"
199+
200+
# make sure returning a dict uses the keys as output column headers
201+
202+
labels = ["label1", "label2"]
203+
204+
def scs_dict(*args, **kwargs):
205+
output = scs(*args, **kwargs)
206+
return dict(zip(labels, len(labels) * [output]))
207+
208+
df = design_space.run(scs_dict).to_dataframe()
209+
for label in labels:
210+
assert label in df, "dict key not parsed properly as column header"
211+
for value in df[label]:
212+
assert not isinstance(value, dict), "dict saved instead of value"
213+
214+
# make sure returning a list assigns column labels properly
215+
216+
num_outputs = 3
217+
label_keys = tdd.Result.default_value_keys(num_outputs * [0.0])
218+
219+
def scs_list(*args, **kwargs):
220+
output = scs(*args, **kwargs)
221+
return num_outputs * [output]
222+
223+
df = design_space.run(scs_list).to_dataframe()
224+
for label in label_keys:
225+
assert label in df, "dict key not parsed properly as column header"
226+
for value in df[label]:
227+
assert not isinstance(value, (tuple, list)), "dict saved instead of value"
228+
229+
230+
def test_method_custom_validators():
231+
"""Test the MethodRandomCustom validation performs as expected."""
232+
233+
d = 3
234+
235+
# expected case
236+
class SamplerWorks:
237+
def random(self, n):
238+
return np.random.random((n, d))
239+
240+
tdd.MethodRandomCustom(num_points=5, sampler=SamplerWorks()),
241+
242+
# missing random method case
243+
class SamplerNoRandom:
244+
pass
245+
246+
with pytest.raises(ValueError):
247+
tdd.MethodRandomCustom(num_points=5, sampler=SamplerNoRandom()),
248+
249+
# random method gives a list
250+
class SamplerList:
251+
def random(self, n):
252+
return np.random.random((n, d)).tolist()
253+
254+
with pytest.raises(ValueError):
255+
tdd.MethodRandomCustom(num_points=5, sampler=SamplerList()),
256+
257+
# random method gives wrong number of dimensions
258+
class SamplerWrongDims:
259+
def random(self, n):
260+
return np.random.random((n, d, d))
261+
262+
with pytest.raises(ValueError):
263+
tdd.MethodRandomCustom(num_points=5, sampler=SamplerWrongDims()),
264+
265+
# random method gives wrong first dimension length
266+
class SamplerWrongShape:
267+
def random(self, n):
268+
return np.random.random((n + 1, d))
269+
270+
with pytest.raises(ValueError):
271+
tdd.MethodRandomCustom(num_points=5, sampler=SamplerWrongShape()),
272+
273+
# random method gives floats outside of range of 0, 1
274+
class SamplerOutOfRange:
275+
def random(self, n):
276+
return 3 * np.random.random((n, d)) - 1
277+
278+
with pytest.raises(ValueError):
279+
tdd.MethodRandomCustom(num_points=5, sampler=SamplerOutOfRange()),
280+
281+
282+
@pytest.mark.parametrize(
283+
"monte_carlo_warning, log_level_expected",
284+
[(True, "WARNING"), (False, None)],
285+
ids=["warn_monte_carlo", "no_warn_monte_carlo"],
286+
)
287+
def test_method_random_warning(log_capture, monte_carlo_warning, log_level_expected):
288+
"""Test that method random validation / warning works as expected."""
289+
290+
method = tdd.MethodRandom(num_points=10, monte_carlo_warning=monte_carlo_warning)
291+
assert_log_level(log_capture, log_level_expected)
292+
293+
294+
@pytest.mark.parametrize(
295+
"parameter",
296+
[
297+
tdd.ParameterAny(name="test", allowed_values=("a", "b", "c")),
298+
tdd.ParameterFloat(name="test", span=(0, 1)),
299+
tdd.ParameterInt(name="test", span=(1, 5)),
300+
],
301+
ids=["any", "float", "int"],
302+
)
303+
def test_random_sampling(parameter):
304+
"""just make sure sample_random still works in case we need it."""
305+
parameter.sample_random(10)

0 commit comments

Comments
 (0)