|
| 1 | +"""Test the parameter sweep plugin.""" |
| 2 | +import pytest |
| 3 | +import numpy as np |
| 4 | +import tidy3d as td |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import scipy.stats.qmc as qmc |
| 7 | + |
| 8 | +import tidy3d.web as web |
| 9 | +from tidy3d.components.base import Tidy3dBaseModel |
| 10 | + |
| 11 | +from tidy3d.plugins import design as tdd |
| 12 | +from tidy3d.plugins.design.method import MethodIndependent |
| 13 | + |
| 14 | +from ..utils import run_emulated, log_capture, assert_log_level |
| 15 | + |
| 16 | + |
| 17 | +SWEEP_METHODS = dict( |
| 18 | + grid=tdd.MethodGrid(), |
| 19 | + monte_carlo=tdd.MethodMonteCarlo(num_points=5), |
| 20 | + custom=tdd.MethodRandomCustom(num_points=5, sampler=qmc.Halton(d=3)), |
| 21 | + random=tdd.MethodRandom(num_points=5), # TODO: remove this if not used |
| 22 | +) |
| 23 | + |
| 24 | + |
| 25 | +@pytest.mark.parametrize("sweep_method", SWEEP_METHODS.values(), ids=SWEEP_METHODS.keys()) |
| 26 | +def test_sweep(sweep_method, monkeypatch, ids=[]): |
| 27 | + # Problem, simulate scattering cross section of sphere ensemble |
| 28 | + # simulation consists of `num_spheres` spheres of radius `radius`. |
| 29 | + # use defines `scs` function to set up and run simulation as function of inputs. |
| 30 | + # then postprocesses the data to give the SCS. |
| 31 | + |
| 32 | + monkeypatch.setattr(web, "run", run_emulated) |
| 33 | + |
| 34 | + def emulated_batch_run(self, simulations, path_dir: str = None, **kwargs): |
| 35 | + data_dict = {task_name: run_emulated(sim) for task_name, sim in simulations.items()} |
| 36 | + task_ids = dict(zip(simulations.keys(), data_dict.keys())) |
| 37 | + |
| 38 | + class BatchDataEmulated(Tidy3dBaseModel): |
| 39 | + """Emulated BatchData object that just returns stored emulated data.""" |
| 40 | + |
| 41 | + data_dict: dict |
| 42 | + task_ids: dict |
| 43 | + |
| 44 | + def items(self): |
| 45 | + for task_name, sim_data in self.data_dict.items(): |
| 46 | + yield task_name, sim_data |
| 47 | + |
| 48 | + def __getitem__(self, task_name): |
| 49 | + return self.data_dict[task_name] |
| 50 | + |
| 51 | + return BatchDataEmulated(data_dict=data_dict, task_ids=task_ids) |
| 52 | + |
| 53 | + monkeypatch.setattr(MethodIndependent, "_run_batch", emulated_batch_run) |
| 54 | + |
| 55 | + # STEP1: define your design function (inputs and outputs) |
| 56 | + |
| 57 | + def scs_pre(radius: float, num_spheres: int, tag: str) -> td.Simulation: |
| 58 | + """Preprocessing function (make simulation)""" |
| 59 | + |
| 60 | + # set up simulation |
| 61 | + spheres = [] |
| 62 | + |
| 63 | + for i in range(int(num_spheres)): |
| 64 | + spheres.append( |
| 65 | + td.Structure( |
| 66 | + geometry=td.Sphere(radius=radius), |
| 67 | + medium=td.PEC, |
| 68 | + ) |
| 69 | + ) |
| 70 | + |
| 71 | + mnt = td.FieldMonitor( |
| 72 | + size=(0, 0, 0), |
| 73 | + center=(0, 0, 0), |
| 74 | + freqs=[2e14], |
| 75 | + name="field", |
| 76 | + ) |
| 77 | + |
| 78 | + return td.Simulation( |
| 79 | + size=(1, 1, 1), |
| 80 | + structures=spheres, |
| 81 | + grid_spec=td.GridSpec.auto(wavelength=1.0), |
| 82 | + run_time=1e-12, |
| 83 | + monitors=[mnt], |
| 84 | + ) |
| 85 | + |
| 86 | + def scs_post(sim_data: td.SimulationData) -> float: |
| 87 | + """Postprocessing function (analyze simulation data)""" |
| 88 | + |
| 89 | + mnt_data = sim_data["field"] |
| 90 | + ex_values = mnt_data.Ex.values |
| 91 | + |
| 92 | + # generate a random number to add some variance to data |
| 93 | + np.random.seed(hash(sim_data) % 1000) |
| 94 | + |
| 95 | + return np.sum(np.square(np.abs(ex_values))) + np.random.random() |
| 96 | + |
| 97 | + def scs_pre_multi(*args, **kwargs): |
| 98 | + sim = scs_pre(*args, **kwargs) |
| 99 | + |
| 100 | + return [sim, sim, sim] |
| 101 | + |
| 102 | + def scs_post_multi(*sim_datas): |
| 103 | + vals = [scs_post(sim_data) for sim_data in sim_datas] |
| 104 | + return np.mean(vals) |
| 105 | + |
| 106 | + def scs_pre_dict(*args, **kwargs): |
| 107 | + sims = scs_pre_multi(*args, **kwargs) |
| 108 | + keys = "abc" |
| 109 | + return dict(zip(keys, sims)) |
| 110 | + |
| 111 | + def scs_post_dict(a=None, b=None, c=None): |
| 112 | + sims = [a, b, c] |
| 113 | + return scs_post_multi(*sims) |
| 114 | + |
| 115 | + def scs(radius: float, num_spheres: int, tag: str) -> float: |
| 116 | + """End to end function.""" |
| 117 | + |
| 118 | + sim = scs_pre(radius=radius, num_spheres=num_spheres, tag=tag) |
| 119 | + |
| 120 | + # run simulation |
| 121 | + sim_data = run_emulated(sim, task_name=f"SWEEP_{tag}") |
| 122 | + |
| 123 | + # postprocess |
| 124 | + return scs_post(sim_data=sim_data) |
| 125 | + |
| 126 | + # STEP2: define your design problem |
| 127 | + |
| 128 | + radius_variable = tdd.ParameterFloat( |
| 129 | + name="radius", |
| 130 | + span=(0, 1.5), |
| 131 | + num_points=5, # note: only used for MethodGrid |
| 132 | + ) |
| 133 | + |
| 134 | + num_spheres_variable = tdd.ParameterInt( |
| 135 | + name="num_spheres", |
| 136 | + span=(0, 3), |
| 137 | + ) |
| 138 | + |
| 139 | + tag_variable = tdd.ParameterAny(name="tag", allowed_values=("tag1", "tag2", "tag3")) |
| 140 | + |
| 141 | + design_space = tdd.DesignSpace( |
| 142 | + parameters=[radius_variable, num_spheres_variable, tag_variable], |
| 143 | + method=sweep_method, |
| 144 | + name="sphere CS", |
| 145 | + ) |
| 146 | + |
| 147 | + # STEP3: Run your design problem |
| 148 | + |
| 149 | + # either supply generic function and run one by one |
| 150 | + sweep_results = design_space.run(scs) |
| 151 | + |
| 152 | + # or supply function factored into pre and post and run in batch |
| 153 | + sweep_results2 = design_space.run_batch(scs_pre, scs_post) |
| 154 | + |
| 155 | + sweep_results3 = design_space.run_batch(scs_pre_multi, scs_post_multi) |
| 156 | + |
| 157 | + sweep_results4 = design_space.run_batch(scs_pre_dict, scs_post_dict) |
| 158 | + |
| 159 | + sel_kwargs_0 = dict(zip(sweep_results.dims, sweep_results.coords[0])) |
| 160 | + sweep_results.sel(**sel_kwargs_0) |
| 161 | + |
| 162 | + print(sweep_results.to_dataframe().head(10)) |
| 163 | + |
| 164 | + im = sweep_results.to_dataframe().plot.hexbin(x="num_spheres", y="radius", C="output") |
| 165 | + im = sweep_results.to_dataframe().plot.scatter(x="num_spheres", y="radius", c="output") |
| 166 | + plt.close() |
| 167 | + |
| 168 | + design_space2 = tdd.DesignSpace( |
| 169 | + parameters=[radius_variable, num_spheres_variable, tag_variable], |
| 170 | + method=tdd.MethodMonteCarlo(num_points=3), |
| 171 | + name="sphere CS", |
| 172 | + ) |
| 173 | + |
| 174 | + sweep_results_other = design_space2.run(scs) |
| 175 | + |
| 176 | + results_combined = sweep_results.combine(sweep_results_other) |
| 177 | + results_combined = sweep_results + sweep_results_other |
| 178 | + |
| 179 | + # STEP4: modify the sweep results |
| 180 | + |
| 181 | + sweep_results = sweep_results.add( |
| 182 | + fn_args={"radius": 1.2, "num_spheres": 5, "tag": "tag2"}, value=1.9 |
| 183 | + ) |
| 184 | + |
| 185 | + sweep_results = sweep_results.delete(fn_args={"num_spheres": 5, "tag": "tag2", "radius": 1.2}) |
| 186 | + |
| 187 | + sweep_results_df = sweep_results.to_dataframe() |
| 188 | + |
| 189 | + sweep_results_2 = tdd.Result.from_dataframe(sweep_results_df) |
| 190 | + sweep_results_3 = tdd.Result.from_dataframe(sweep_results_df, dims=sweep_results.dims) |
| 191 | + |
| 192 | + assert sweep_results == sweep_results_2 == sweep_results_3 |
| 193 | + |
| 194 | + # VALIDATE PROPER DATAFRAME HEADERS AND DATA STORAGE |
| 195 | + |
| 196 | + # make sure returning a float uses the proper output column header |
| 197 | + float_label = tdd.Result.default_value_keys(1.0)[0] |
| 198 | + assert float_label in sweep_results_df, "didn't assign column header properly for float" |
| 199 | + |
| 200 | + # make sure returning a dict uses the keys as output column headers |
| 201 | + |
| 202 | + labels = ["label1", "label2"] |
| 203 | + |
| 204 | + def scs_dict(*args, **kwargs): |
| 205 | + output = scs(*args, **kwargs) |
| 206 | + return dict(zip(labels, len(labels) * [output])) |
| 207 | + |
| 208 | + df = design_space.run(scs_dict).to_dataframe() |
| 209 | + for label in labels: |
| 210 | + assert label in df, "dict key not parsed properly as column header" |
| 211 | + for value in df[label]: |
| 212 | + assert not isinstance(value, dict), "dict saved instead of value" |
| 213 | + |
| 214 | + # make sure returning a list assigns column labels properly |
| 215 | + |
| 216 | + num_outputs = 3 |
| 217 | + label_keys = tdd.Result.default_value_keys(num_outputs * [0.0]) |
| 218 | + |
| 219 | + def scs_list(*args, **kwargs): |
| 220 | + output = scs(*args, **kwargs) |
| 221 | + return num_outputs * [output] |
| 222 | + |
| 223 | + df = design_space.run(scs_list).to_dataframe() |
| 224 | + for label in label_keys: |
| 225 | + assert label in df, "dict key not parsed properly as column header" |
| 226 | + for value in df[label]: |
| 227 | + assert not isinstance(value, (tuple, list)), "dict saved instead of value" |
| 228 | + |
| 229 | + |
| 230 | +def test_method_custom_validators(): |
| 231 | + """Test the MethodRandomCustom validation performs as expected.""" |
| 232 | + |
| 233 | + d = 3 |
| 234 | + |
| 235 | + # expected case |
| 236 | + class SamplerWorks: |
| 237 | + def random(self, n): |
| 238 | + return np.random.random((n, d)) |
| 239 | + |
| 240 | + tdd.MethodRandomCustom(num_points=5, sampler=SamplerWorks()), |
| 241 | + |
| 242 | + # missing random method case |
| 243 | + class SamplerNoRandom: |
| 244 | + pass |
| 245 | + |
| 246 | + with pytest.raises(ValueError): |
| 247 | + tdd.MethodRandomCustom(num_points=5, sampler=SamplerNoRandom()), |
| 248 | + |
| 249 | + # random method gives a list |
| 250 | + class SamplerList: |
| 251 | + def random(self, n): |
| 252 | + return np.random.random((n, d)).tolist() |
| 253 | + |
| 254 | + with pytest.raises(ValueError): |
| 255 | + tdd.MethodRandomCustom(num_points=5, sampler=SamplerList()), |
| 256 | + |
| 257 | + # random method gives wrong number of dimensions |
| 258 | + class SamplerWrongDims: |
| 259 | + def random(self, n): |
| 260 | + return np.random.random((n, d, d)) |
| 261 | + |
| 262 | + with pytest.raises(ValueError): |
| 263 | + tdd.MethodRandomCustom(num_points=5, sampler=SamplerWrongDims()), |
| 264 | + |
| 265 | + # random method gives wrong first dimension length |
| 266 | + class SamplerWrongShape: |
| 267 | + def random(self, n): |
| 268 | + return np.random.random((n + 1, d)) |
| 269 | + |
| 270 | + with pytest.raises(ValueError): |
| 271 | + tdd.MethodRandomCustom(num_points=5, sampler=SamplerWrongShape()), |
| 272 | + |
| 273 | + # random method gives floats outside of range of 0, 1 |
| 274 | + class SamplerOutOfRange: |
| 275 | + def random(self, n): |
| 276 | + return 3 * np.random.random((n, d)) - 1 |
| 277 | + |
| 278 | + with pytest.raises(ValueError): |
| 279 | + tdd.MethodRandomCustom(num_points=5, sampler=SamplerOutOfRange()), |
| 280 | + |
| 281 | + |
| 282 | +@pytest.mark.parametrize( |
| 283 | + "monte_carlo_warning, log_level_expected", |
| 284 | + [(True, "WARNING"), (False, None)], |
| 285 | + ids=["warn_monte_carlo", "no_warn_monte_carlo"], |
| 286 | +) |
| 287 | +def test_method_random_warning(log_capture, monte_carlo_warning, log_level_expected): |
| 288 | + """Test that method random validation / warning works as expected.""" |
| 289 | + |
| 290 | + method = tdd.MethodRandom(num_points=10, monte_carlo_warning=monte_carlo_warning) |
| 291 | + assert_log_level(log_capture, log_level_expected) |
| 292 | + |
| 293 | + |
| 294 | +@pytest.mark.parametrize( |
| 295 | + "parameter", |
| 296 | + [ |
| 297 | + tdd.ParameterAny(name="test", allowed_values=("a", "b", "c")), |
| 298 | + tdd.ParameterFloat(name="test", span=(0, 1)), |
| 299 | + tdd.ParameterInt(name="test", span=(1, 5)), |
| 300 | + ], |
| 301 | + ids=["any", "float", "int"], |
| 302 | +) |
| 303 | +def test_random_sampling(parameter): |
| 304 | + """just make sure sample_random still works in case we need it.""" |
| 305 | + parameter.sample_random(10) |
0 commit comments