|
| 1 | +# test autograd for field sources when symmetry is used in the simulation |
| 2 | +from __future__ import annotations |
| 3 | + |
| 4 | +import operator |
| 5 | +import sys |
| 6 | + |
| 7 | +import autograd as ag |
| 8 | +import matplotlib.pylab as plt |
| 9 | +import numpy as np |
| 10 | +import pytest |
| 11 | + |
| 12 | +import tidy3d as td |
| 13 | +import tidy3d.web as web |
| 14 | + |
| 15 | +PLOT_SYMMETRY_COMPARISON = False |
| 16 | +NUM_FINITE_DIFFERENCE = 10 |
| 17 | +SAVE_FD_ADJ_DATA = False |
| 18 | +SAVE_FD_LOC = 0 |
| 19 | +SAVE_ADJ_LOC = 1 |
| 20 | +LOCAL_GRADIENT = False |
| 21 | +VERBOSE = False |
| 22 | +NUMERICAL_RESULTS_DATA_DIR = "./numerical_symmetry_test/" |
| 23 | +SHOW_PRINT_STATEMENTS = True |
| 24 | + |
| 25 | +RMS_THRESHOLD = 0.25 |
| 26 | + |
| 27 | +if PLOT_SYMMETRY_COMPARISON: |
| 28 | + pytestmark = pytest.mark.usefixtures("mpl_config_interactive") |
| 29 | +else: |
| 30 | + pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") |
| 31 | + |
| 32 | +if SHOW_PRINT_STATEMENTS: |
| 33 | + sys.stdout = sys.stderr |
| 34 | + |
| 35 | + |
| 36 | +FINITE_DIFF_PERM_SEED = 1.5**2 |
| 37 | +MESH_FACTOR_DESIGN = 30.0 |
| 38 | + |
| 39 | + |
| 40 | +def get_sim_geometry(mesh_wvl_um): |
| 41 | + return td.Box(size=(5 * mesh_wvl_um, 5 * mesh_wvl_um, 7 * mesh_wvl_um), center=(0, 0, 0)) |
| 42 | + |
| 43 | + |
| 44 | +def make_base_sim( |
| 45 | + mesh_wvl_um, |
| 46 | + adj_wvl_um, |
| 47 | + monitor_size_wvl, |
| 48 | + box_for_override, |
| 49 | + symmetry, |
| 50 | + monitor_bg_index=1.0, |
| 51 | + run_time=1e-11, |
| 52 | +): |
| 53 | + sim_geometry = get_sim_geometry(mesh_wvl_um) |
| 54 | + sim_size_um = sim_geometry.size |
| 55 | + sim_center_um = sim_geometry.center |
| 56 | + |
| 57 | + boundary_spec = td.BoundarySpec( |
| 58 | + x=td.Boundary.pml(), |
| 59 | + y=td.Boundary.pml(), |
| 60 | + z=td.Boundary.pml(), |
| 61 | + ) |
| 62 | + |
| 63 | + dl_design = mesh_wvl_um / MESH_FACTOR_DESIGN |
| 64 | + |
| 65 | + mesh_overrides = [] |
| 66 | + mesh_overrides.extend( |
| 67 | + [ |
| 68 | + td.MeshOverrideStructure( |
| 69 | + geometry=box_for_override, |
| 70 | + dl=[dl_design, dl_design, dl_design], |
| 71 | + ), |
| 72 | + ] |
| 73 | + ) |
| 74 | + |
| 75 | + src_size = sim_size_um[0:2] + (0,) |
| 76 | + |
| 77 | + wl_min_src_um = 0.9 * adj_wvl_um |
| 78 | + wl_max_src_um = 1.1 * adj_wvl_um |
| 79 | + |
| 80 | + fwidth_src = td.C_0 * ((1.0 / wl_min_src_um) - (1.0 / wl_max_src_um)) |
| 81 | + freq0 = td.C_0 / adj_wvl_um |
| 82 | + |
| 83 | + pulse = td.GaussianPulse(freq0=freq0, fwidth=fwidth_src) |
| 84 | + |
| 85 | + src = td.PlaneWave( |
| 86 | + center=(0, 0, -2 * mesh_wvl_um), |
| 87 | + size=(td.inf, td.inf, 0), |
| 88 | + direction="+", |
| 89 | + pol_angle=0, |
| 90 | + angle_theta=0, |
| 91 | + source_time=pulse, |
| 92 | + ) |
| 93 | + |
| 94 | + field_monitor = td.FieldMonitor( |
| 95 | + center=(0, 0, 0.25 * sim_size_um[2]), |
| 96 | + size=tuple(dim * mesh_wvl_um for dim in monitor_size_wvl), |
| 97 | + name="monitor_fields", |
| 98 | + freqs=[freq0], |
| 99 | + ) |
| 100 | + |
| 101 | + monitor_index_block = td.Box( |
| 102 | + center=(0, 0, 0.25 * sim_size_um[2] + mesh_wvl_um), |
| 103 | + size=(*tuple(2 * size for size in sim_size_um[0:2]), mesh_wvl_um + 0.5 * sim_size_um[2]), |
| 104 | + ) |
| 105 | + monitor_index_block_structure = td.Structure( |
| 106 | + geometry=monitor_index_block, medium=td.Medium(permittivity=monitor_bg_index**2) |
| 107 | + ) |
| 108 | + |
| 109 | + sim_base = td.Simulation( |
| 110 | + center=sim_center_um, |
| 111 | + size=sim_size_um, |
| 112 | + grid_spec=td.GridSpec.auto( |
| 113 | + min_steps_per_wvl=30, |
| 114 | + wavelength=mesh_wvl_um, |
| 115 | + override_structures=mesh_overrides, |
| 116 | + ), |
| 117 | + structures=[monitor_index_block_structure], |
| 118 | + sources=[src], |
| 119 | + monitors=[field_monitor], |
| 120 | + run_time=run_time, |
| 121 | + boundary_spec=boundary_spec, |
| 122 | + subpixel=True, |
| 123 | + symmetry=symmetry, |
| 124 | + ) |
| 125 | + |
| 126 | + return sim_base |
| 127 | + |
| 128 | + |
| 129 | +def create_objective_functions(geometry, create_sim_base, eval_fn, sim_path_dir): |
| 130 | + def objective_(perm_array, symmetry): |
| 131 | + sim_base = create_sim_base(symmetry) |
| 132 | + |
| 133 | + block_structure = td.Structure.from_permittivity_array( |
| 134 | + eps_data=perm_array, |
| 135 | + geometry=geometry, |
| 136 | + ) |
| 137 | + |
| 138 | + sim_with_block = sim_base.updated_copy(structures=(*sim_base.structures, block_structure)) |
| 139 | + |
| 140 | + sim_data = web.run( |
| 141 | + sim_with_block, |
| 142 | + task_name="symmetry_field_testing", |
| 143 | + local_gradient=LOCAL_GRADIENT, |
| 144 | + verbose=VERBOSE, |
| 145 | + ) |
| 146 | + |
| 147 | + objective_val = eval_fn(sim_data) |
| 148 | + |
| 149 | + return objective_val |
| 150 | + |
| 151 | + def objective_no_symmetry(perm_array): |
| 152 | + return objective_(perm_array=perm_array, symmetry=(0, 0, 0)) |
| 153 | + |
| 154 | + def objective_x_symmetry(perm_array): |
| 155 | + return objective_(perm_array=perm_array, symmetry=(-1, 0, 0)) |
| 156 | + |
| 157 | + def objective_y_symmetry(perm_array): |
| 158 | + return objective_(perm_array=perm_array, symmetry=(0, 1, 0)) |
| 159 | + |
| 160 | + def objective_xy_symmetry(perm_array): |
| 161 | + return objective_(perm_array=perm_array, symmetry=(-1, 1, 0)) |
| 162 | + |
| 163 | + return objective_no_symmetry, objective_x_symmetry, objective_y_symmetry, objective_xy_symmetry |
| 164 | + |
| 165 | + |
| 166 | +def make_eval_fns(monitor_size_wvl): |
| 167 | + num_nonzero_spatial_dims = 3 - np.sum(np.isclose(monitor_size_wvl, 0)) |
| 168 | + |
| 169 | + def intensity(sim_data): |
| 170 | + field_data = sim_data["monitor_fields"] |
| 171 | + shape_x, shape_y, shape_z, *_ = field_data.Ex.values.shape |
| 172 | + |
| 173 | + return np.sum(np.abs(field_data.Ex.values) ** 2 + np.abs(field_data.Ey.values) ** 2) |
| 174 | + |
| 175 | + eval_fns = [intensity] |
| 176 | + eval_fn_names = ["intensity"] |
| 177 | + |
| 178 | + if num_nonzero_spatial_dims == 2: |
| 179 | + |
| 180 | + def flux(sim_data): |
| 181 | + field_data = sim_data["monitor_fields"] |
| 182 | + |
| 183 | + return np.sum(field_data.flux.values) |
| 184 | + |
| 185 | + eval_fns.append(flux) |
| 186 | + eval_fn_names.append("flux") |
| 187 | + |
| 188 | + return eval_fns, eval_fn_names |
| 189 | + |
| 190 | + |
| 191 | +background_indices = [1.0] |
| 192 | +mesh_wvls_um = [1.55] |
| 193 | +adj_wvls_um = [1.55] |
| 194 | +monitor_sizes_3d_wvl = [(0.5, 0.5, 0)] |
| 195 | + |
| 196 | +field_symmetry_test_parameters = [] |
| 197 | + |
| 198 | +test_number = 0 |
| 199 | +for idx in range(len(mesh_wvls_um)): |
| 200 | + mesh_wvl_um = mesh_wvls_um[idx] |
| 201 | + adj_wvl_um = adj_wvls_um[idx] |
| 202 | + |
| 203 | + for monitor_size_wvl in monitor_sizes_3d_wvl: |
| 204 | + eval_fns, eval_fn_names = make_eval_fns(monitor_size_wvl) |
| 205 | + |
| 206 | + for monitor_bg_index in background_indices: |
| 207 | + for eval_fn_idx, eval_fn in enumerate(eval_fns): |
| 208 | + field_symmetry_test_parameters.append( |
| 209 | + { |
| 210 | + "mesh_wvl_um": mesh_wvl_um, |
| 211 | + "adj_wvl_um": adj_wvl_um, |
| 212 | + "monitor_size_wvl": monitor_size_wvl, |
| 213 | + "monitor_bg_index": monitor_bg_index, |
| 214 | + "eval_fn": eval_fn, |
| 215 | + "eval_fn_name": eval_fn_names[eval_fn_idx], |
| 216 | + "test_number": test_number, |
| 217 | + } |
| 218 | + ) |
| 219 | + |
| 220 | + test_number += 1 |
| 221 | + |
| 222 | + |
| 223 | +@pytest.mark.numerical |
| 224 | +@pytest.mark.parametrize( |
| 225 | + "field_symmetry_test_parameters, dir_name", |
| 226 | + zip( |
| 227 | + field_symmetry_test_parameters, |
| 228 | + ([NUMERICAL_RESULTS_DATA_DIR] if SAVE_FD_ADJ_DATA else [None]) |
| 229 | + * len(field_symmetry_test_parameters), |
| 230 | + ), |
| 231 | + indirect=["dir_name"], |
| 232 | +) |
| 233 | +def test_adjoint_difference_symmetry( |
| 234 | + field_symmetry_test_parameters, rng, tmp_path, create_directory |
| 235 | +): |
| 236 | + """Test the gradient is not affected by symmetry when using field sources.""" |
| 237 | + |
| 238 | + num_tests = 0 |
| 239 | + for monitor_size_wvl in monitor_sizes_3d_wvl: |
| 240 | + eval_fns, _ = make_eval_fns(monitor_size_wvl) |
| 241 | + num_tests += len(eval_fns) * len(background_indices) * len(mesh_wvls_um) |
| 242 | + |
| 243 | + test_results = np.zeros((2, NUM_FINITE_DIFFERENCE)) |
| 244 | + |
| 245 | + test_number = field_symmetry_test_parameters["test_number"] |
| 246 | + |
| 247 | + ( |
| 248 | + mesh_wvl_um, |
| 249 | + adj_wvl_um, |
| 250 | + monitor_size_wvl, |
| 251 | + monitor_bg_index, |
| 252 | + eval_fn, |
| 253 | + eval_fn_name, |
| 254 | + test_number, |
| 255 | + ) = operator.itemgetter( |
| 256 | + "mesh_wvl_um", |
| 257 | + "adj_wvl_um", |
| 258 | + "monitor_size_wvl", |
| 259 | + "monitor_bg_index", |
| 260 | + "eval_fn", |
| 261 | + "eval_fn_name", |
| 262 | + "test_number", |
| 263 | + )(field_symmetry_test_parameters) |
| 264 | + |
| 265 | + dim_um = mesh_wvl_um |
| 266 | + dim_um = mesh_wvl_um |
| 267 | + thickness_um = 0.5 * mesh_wvl_um |
| 268 | + block = td.Box(center=(0, 0, 0), size=(dim_um, dim_um, thickness_um)) |
| 269 | + |
| 270 | + dim = 1 + int(dim_um / (mesh_wvl_um / MESH_FACTOR_DESIGN)) |
| 271 | + Nz = 1 + int(thickness_um / (mesh_wvl_um / MESH_FACTOR_DESIGN)) |
| 272 | + |
| 273 | + sim_geometry = get_sim_geometry(mesh_wvl_um) |
| 274 | + |
| 275 | + box_for_override = td.Box( |
| 276 | + center=(0, 0, 0), size=sim_geometry.size[0:2] + (thickness_um + mesh_wvl_um,) |
| 277 | + ) |
| 278 | + |
| 279 | + eval_fns, eval_fn_names = make_eval_fns(monitor_size_wvl) |
| 280 | + |
| 281 | + sim_path_dir = tmp_path / f"test{test_number}" |
| 282 | + sim_path_dir.mkdir() |
| 283 | + |
| 284 | + objective_no_symmetry, objective_x_symmetry, objective_y_symmetry, objective_xy_symmetry = ( |
| 285 | + create_objective_functions( |
| 286 | + block, |
| 287 | + lambda symmetry, |
| 288 | + mesh_wvl_um=mesh_wvl_um, |
| 289 | + adj_wvl_um=adj_wvl_um, |
| 290 | + monitor_size_wvl=monitor_size_wvl, |
| 291 | + box_for_override=box_for_override, |
| 292 | + monitor_bg_index=monitor_bg_index: make_base_sim( |
| 293 | + mesh_wvl_um=mesh_wvl_um, |
| 294 | + adj_wvl_um=adj_wvl_um, |
| 295 | + monitor_size_wvl=monitor_size_wvl, |
| 296 | + box_for_override=box_for_override, |
| 297 | + monitor_bg_index=monitor_bg_index, |
| 298 | + symmetry=symmetry, |
| 299 | + ), |
| 300 | + eval_fn, |
| 301 | + sim_path_dir=str(sim_path_dir), |
| 302 | + ) |
| 303 | + ) |
| 304 | + |
| 305 | + obj_val_and_grad_no_symmetry = ag.value_and_grad(objective_no_symmetry) |
| 306 | + obj_val_and_grad_x_symmetry = ag.value_and_grad(objective_x_symmetry) |
| 307 | + obj_val_and_grad_y_symmetry = ag.value_and_grad(objective_y_symmetry) |
| 308 | + obj_val_and_grad_xy_symmetry = ag.value_and_grad(objective_xy_symmetry) |
| 309 | + |
| 310 | + objs_val_and_grad = [ |
| 311 | + obj_val_and_grad_no_symmetry, |
| 312 | + obj_val_and_grad_x_symmetry, |
| 313 | + obj_val_and_grad_y_symmetry, |
| 314 | + obj_val_and_grad_xy_symmetry, |
| 315 | + ] |
| 316 | + |
| 317 | + symmetries = ["none", "x", "y", "xy"] |
| 318 | + |
| 319 | + objs = [] |
| 320 | + adj_grads = [] |
| 321 | + |
| 322 | + perm_init = FINITE_DIFF_PERM_SEED * np.ones((dim, dim, Nz)) |
| 323 | + |
| 324 | + for obj_val_and_grad in objs_val_and_grad: |
| 325 | + obj, adj_grad = obj_val_and_grad(perm_init) |
| 326 | + |
| 327 | + objs.append(obj) |
| 328 | + adj_grads.append(np.array(adj_grad)) |
| 329 | + |
| 330 | + grad_data_base = adj_grads[0] / objs[0] |
| 331 | + for idx in range(1, len(adj_grads)): |
| 332 | + # field magnitudes can be different for different symmetries so we expect the gradients |
| 333 | + # to scale with the objecive values |
| 334 | + grad_data = adj_grads[idx] / objs[idx] |
| 335 | + |
| 336 | + mag_base = np.sqrt(np.mean(grad_data_base**2)) |
| 337 | + mag_compare = np.sqrt(np.mean(grad_data**2)) |
| 338 | + rms_error = np.sqrt(np.mean((grad_data_base - grad_data) ** 2)) |
| 339 | + |
| 340 | + if SHOW_PRINT_STATEMENTS: |
| 341 | + print(f"Testing {eval_fn_name} objective") |
| 342 | + print(f"Symmetry comparison: {symmetries[0]}, {symmetries[idx]}") |
| 343 | + print(f"RMS error (normalized): {rms_error / np.sqrt(mag_base * mag_compare)}") |
| 344 | + |
| 345 | + assert np.isclose(rms_error / np.sqrt(mag_base * mag_compare), 0.0, atol=0.075), ( |
| 346 | + "Expected adjoint gradients to be the same with and without symmetry" |
| 347 | + ) |
| 348 | + |
| 349 | + if PLOT_SYMMETRY_COMPARISON: |
| 350 | + plot_grad_data_base = np.squeeze(grad_data_base) |
| 351 | + plot_grad_data = np.squeeze(grad_data) |
| 352 | + plot_diff = plot_grad_data - plot_grad_data_base |
| 353 | + |
| 354 | + plt.subplot(1, 3, 1) |
| 355 | + plt.imshow(plot_grad_data_base[:, :, plot_grad_data_base.shape[2] // 2]) |
| 356 | + plt.title(f"Symmetry: {symmetries[0]}") |
| 357 | + plt.colorbar() |
| 358 | + plt.subplot(1, 3, 2) |
| 359 | + plt.imshow(plot_grad_data[:, :, plot_grad_data.shape[2] // 2]) |
| 360 | + plt.title(f"Symmetry: {symmetries[idx]}") |
| 361 | + plt.colorbar() |
| 362 | + plt.subplot(1, 3, 3) |
| 363 | + plt.imshow(plot_diff[:, :, plot_diff.shape[2] // 2]) |
| 364 | + plt.title("Difference") |
| 365 | + plt.colorbar() |
| 366 | + plt.show() |
0 commit comments