Skip to content

Commit a69361d

Browse files
tylerflexmomchil-flex
authored andcommitted
several adjoint fixes regarding JaxDataArray indexing and symmetry expansion
1 parent e7ae68d commit a69361d

File tree

7 files changed

+118
-52
lines changed

7 files changed

+118
-52
lines changed

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313

1414
### Fixed
1515
- Ensure same `Grid` is generated in forward and adjoint simulations by setting `GridSpec.wavelength` manually in adjoint.
16-
- Proper handling of `JaxBox` derivatives both for multi-cell and single cell thickness.
16+
- Properly handling of `JaxBox` derivatives both for multi-cell and single cell thickness.
17+
- Properly handle `JaxSimulation.monitors` with `.freqs` as `np.ndarray` in adjoint plugin.
18+
- Properly handle `JaxDataArray.sel()` with single coordinates and symmetry expansion.
19+
- Properly handle `JaxDataArray * xr.DataArray` broadcasting.
20+
- Stricter validation of `JaxDataArray` coordinates and values shape.
21+
1722

1823
## [2.4.1] - 2023-9-20
1924

tests/test_plugins/test_adjoint.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from tidy3d.web.container import BatchData
3636

3737
from ..utils import run_emulated, assert_log_level, log_capture, run_async_emulated
38+
from ..test_components.test_custom import CUSTOM_MEDIUM
3839

3940
TMP_PATH = None
4041
FWD_SIM_DATA_FILE = "adjoint_grad_data_fwd.hdf5"
@@ -766,6 +767,27 @@ def test_jax_data_array():
766767
da1d = JaxDataArray(values=[0.0, 1.0, 2.0, 3.0], coords=dict(x=[0, 1, 2, 3]))
767768
assert np.isclose(da1d.interp(x=0.5), 0.5)
768769

770+
# duplicate coordinates
771+
sel_a_coords = [1, 2, 3, 2, 1]
772+
res = da.sel(a=sel_a_coords)
773+
assert res.coords["a"] == sel_a_coords
774+
assert res.values.shape[0] == len(sel_a_coords)
775+
776+
a = [1, 2, 3]
777+
b = [2, 3, 4, 5]
778+
c = [4, 6]
779+
shape = (len(a), len(b), len(c))
780+
values = np.random.random(shape)
781+
coords = dict(a=a, b=b, c=c)
782+
da = JaxDataArray(values=values, coords=coords)
783+
da2 = da.sel(b=[3, 4])
784+
assert da2.shape == (3, 2, 2)
785+
786+
da3 = da.interp(b=np.array([3, 4]))
787+
assert da3.shape == (3, 2, 2)
788+
789+
assert da2 == da3
790+
769791

770792
def test_jax_sim_data(use_emulated_run):
771793
"""Test mechanics of the JaxSimulationData."""
@@ -1460,7 +1482,8 @@ def test_sim_data_plot_field(use_emulated_run):
14601482
assert len(ax.collections) == 1
14611483

14621484

1463-
def test_polyslab_structures(use_emulated_run):
1485+
def test_pytreedef_errors(use_emulated_run):
1486+
"""Fix errors that occur when jax doesnt know how to handle array types in aux_data."""
14641487

14651488
vertices = [(0, 0), (1, 0), (1, 1), (0, 1)]
14661489
polyslab = td.PolySlab(vertices=vertices, slab_bounds=(0, 1), axis=2)
@@ -1482,12 +1505,24 @@ def test_polyslab_structures(use_emulated_run):
14821505
medium=td.Medium(),
14831506
)
14841507

1508+
flux_mnt = td.FieldMonitor(
1509+
center=(0, 0, 0),
1510+
size=(1, 1, 0),
1511+
freqs=1e14 * np.array([1, 2, 3]), # this previously errored
1512+
name="flux",
1513+
)
1514+
14851515
VERTICES = np.array(
14861516
[[-1.5, -0.5, -0.5], [-0.5, -0.5, -0.5], [-1.5, 0.5, -0.5], [-1.5, -0.5, 0.5]]
14871517
)
14881518
FACES = np.array([[1, 2, 3], [0, 3, 2], [0, 1, 3], [0, 2, 1]])
14891519
STL_GEO = td.TriangleMesh.from_trimesh(trimesh.Trimesh(VERTICES, FACES))
14901520

1521+
custom_medium = td.Structure(
1522+
geometry=td.Box(size=(1, 1, 1)),
1523+
medium=CUSTOM_MEDIUM,
1524+
)
1525+
14911526
stl_struct = td.Structure(geometry=STL_GEO, medium=td.Medium())
14921527

14931528
mnt = td.ModeMonitor(size=(1, 1, 0), freqs=[1e14], name="test", mode_spec=td.ModeSpec())
@@ -1499,10 +1534,11 @@ def f(x):
14991534

15001535
sim = JaxSimulation(
15011536
size=(2.0, 2.0, 2.0),
1502-
structures=[ps, gg, ggg, gggg, stl_struct],
1537+
structures=[ps, gg, ggg, gggg, stl_struct, custom_medium],
15031538
input_structures=[js],
15041539
run_time=1e-12,
15051540
output_monitors=[mnt],
1541+
monitors=[flux_mnt],
15061542
grid_spec=td.GridSpec.uniform(dl=0.1),
15071543
boundary_spec=td.BoundarySpec.pml(x=False, y=False, z=False),
15081544
)

tests/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def make_field_data(monitor: td.FieldMonitor) -> td.FieldData:
488488

489489
return td.FieldData(
490490
monitor=monitor,
491-
symmetry=simulation.symmetry,
491+
symmetry=(0, 0, 0),
492492
symmetry_center=simulation.center,
493493
grid_expanded=grid,
494494
**field_cmps,

tidy3d/components/data/monitor_data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ def _symmetry_update_dict(self) -> Dict:
186186

187187
# Interpolate. There generally shouldn't be values out of bounds except potentially
188188
# when handling modes, in which case they should be at the boundary and close to 0.
189-
scalar_data = scalar_data.sel({dim_name: coords_interp}, method="nearest")
189+
190+
scalar_data = scalar_data.sel(**{dim_name: coords_interp}, method="nearest")
190191
scalar_data = scalar_data.assign_coords({dim_name: coords})
191192

192193
# apply the symmetry eigenvalue (if defined) to the flipped values

tidy3d/plugins/adjoint/components/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from typing import Tuple, List, Any, Callable
55
import json
66

7+
import numpy as np
8+
79
from jax.tree_util import tree_flatten as jax_tree_flatten
810
from jax.tree_util import tree_unflatten as jax_tree_unflatten
911

@@ -56,6 +58,12 @@ def fix_polyslab(geo_dict: dict) -> None:
5658
for _i, structure in enumerate(structures):
5759
geometry = structure["geometry"]
5860
fix_polyslab(geometry)
61+
monitors = aux_data["monitors"]
62+
for _i, monitor in enumerate(monitors):
63+
if "freqs" in monitor:
64+
freqs = monitor["freqs"]
65+
if isinstance(freqs, np.ndarray):
66+
monitor["freqs"] = freqs.tolist()
5967

6068
return children, aux_data
6169

tidy3d/plugins/adjoint/components/data/data_array.py

Lines changed: 62 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -40,51 +40,43 @@ class JaxDataArray(Tidy3dBaseModel):
4040
description="Dictionary storing the coordinates, namely ``(direction, f, mode_index)``.",
4141
)
4242

43-
@pd.validator("coords", always=True)
44-
def _convert_coords_to_list(cls, val):
45-
"""Convert supplied coordinates to Dict[str, list]."""
46-
return {coord_name: list(coord_list) for coord_name, coord_list in val.items()}
47-
4843
@pd.validator("values", always=True)
4944
def _convert_values_to_np(cls, val):
5045
"""Convert supplied values to numpy if they are list (from file)."""
5146
if isinstance(val, list):
5247
return np.array(val)
5348
return val
5449

55-
def __eq__(self, other) -> bool:
56-
"""Check if two ``JaxDataArray`` instances are equal."""
57-
return jnp.array_equal(self.values, other.values)
58-
59-
# removed because it was slowing things down.
60-
# @pd.validator("coords", always=True)
61-
# def _coords_match_values(cls, val, values):
62-
# """Make sure the coordinate dimensions and shapes match the values data."""
63-
64-
# values = values.get("values")
65-
66-
# # if values did not pass validation, just skip this validator
67-
# if values is None:
68-
# return None
69-
70-
# # compute the shape, otherwise exit.
71-
# try:
72-
# shape = jnp.array(values).shape
73-
# except TypeError:
74-
# return val
50+
@pd.validator("coords", always=True)
51+
def _coords_match_values(cls, val, values):
52+
"""Make sure the coordinate dimensions and shapes match the values data."""
53+
54+
_values = values.get("values")
55+
56+
# get the shape, handling both regular and jax objects
57+
try:
58+
values_shape = np.array(_values).shape
59+
except TypeError:
60+
values_shape = jnp.array(_values).shape
61+
62+
for (key, coord_val), size_dim in zip(val.items(), values_shape):
63+
if len(coord_val) != size_dim:
64+
raise ValueError(
65+
f"JaxDataArray coord {key} has {len(coord_val)} elements, "
66+
"which doesn't match the values array "
67+
f"with size {size_dim} along that dimension."
68+
)
7569

76-
# if len(shape) != len(val):
77-
# raise AdjointError(f"'values' has '{len(shape)}' dims, but given '{len(val)}'.")
70+
return val
7871

79-
# # make sure each coordinate list has same length as values along that axis
80-
# for len_dim, (coord_name, coord_list) in zip(shape, val.items()):
81-
# if len_dim != len(coord_list):
82-
# raise AdjointError(
83-
# f"coordinate '{coord_name}' has '{len(coord_list)}' elements, "
84-
# f"expected '{len_dim}' to match number of 'values' along this dimension."
85-
# )
72+
@pd.validator("coords", always=True)
73+
def _convert_coords_to_list(cls, val):
74+
"""Convert supplied coordinates to Dict[str, list]."""
75+
return {coord_name: list(coord_list) for coord_name, coord_list in val.items()}
8676

87-
# return val
77+
def __eq__(self, other) -> bool:
78+
"""Check if two ``JaxDataArray`` instances are equal."""
79+
return jnp.array_equal(self.values, other.values)
8880

8981
def to_hdf5(self, fname: str, group_path: str) -> None:
9082
"""Save an xr.DataArray to the hdf5 file with a given path to the group."""
@@ -198,11 +190,18 @@ def __mul__(self, other: JaxDataArray) -> JaxDataArray:
198190
new_values = self.as_jnp_array * other.as_jnp_array
199191
elif isinstance(other, xr.DataArray):
200192

201-
other_values = other.values.reshape(self.values.shape)
202-
new_values = self.as_jnp_array * other_values
193+
# handle case where other is missing dims present in self
194+
new_shape = list(self.shape)
195+
for dim_index, dim in enumerate(self.coords.keys()):
196+
if dim not in other.dims:
197+
other = other.expand_dims(dim=dim)
198+
new_shape[dim_index] = 1
203199

200+
other_values = other.values.reshape(new_shape)
201+
new_values = self.as_jnp_array * other_values
204202
else:
205203
new_values = self.as_jnp_array * other
204+
206205
return self.updated_copy(values=new_values)
207206

208207
def __rmul__(self, other) -> JaxDataArray:
@@ -265,8 +264,10 @@ def isel_single(self, coord_name: str, coord_index: int) -> JaxDataArray:
265264

266265
# if the coord index has more than one item, keep that coordinate
267266
coord_index = np.array(coord_index)
268-
if coord_index.size > 1:
269-
new_coords[coord_name] = coord_index.tolist()
267+
if len(coord_index.shape) >= 1:
268+
coord_indices = coord_index.tolist()
269+
new_coord_vals = [self.coords[coord_name][coord_index] for coord_index in coord_indices]
270+
new_coords[coord_name] = new_coord_vals
270271
else:
271272
new_coords.pop(coord_name)
272273

@@ -306,20 +307,36 @@ def sel(self, indexers: dict = None, method: str = "nearest", **sel_kwargs) -> J
306307
isel_kwargs = {}
307308
for coord_name, sel_kwarg in sel_kwargs.items():
308309
coord_list = self.get_coord_list(coord_name)
309-
if sel_kwarg not in coord_list:
310-
raise DataError(f"Could not select '{coord_name}={sel_kwarg}', value not found.")
311-
coord_index = coord_list.index(sel_kwarg)
312-
isel_kwargs[coord_name] = coord_index
310+
if isinstance(sel_kwarg, (tuple, list, np.ndarray)):
311+
sel_kwarg = list(sel_kwarg)
312+
isel_kwargs[coord_name] = []
313+
for _sel_kwarg in sel_kwarg:
314+
if _sel_kwarg not in coord_list:
315+
raise DataError(
316+
f"Could not select '{coord_name}={_sel_kwarg}', value not found."
317+
)
318+
coord_index = coord_list.index(_sel_kwarg)
319+
isel_kwargs[coord_name].append(coord_index)
320+
else:
321+
if sel_kwarg not in coord_list:
322+
raise DataError(
323+
f"Could not select '{coord_name}={sel_kwarg}', value not found."
324+
)
325+
coord_index = coord_list.index(sel_kwarg)
326+
isel_kwargs[coord_name] = coord_index
313327
return self.isel(**isel_kwargs)
314328

315329
def assign_coords(self, coords: dict = None, **coords_kwargs) -> JaxDataArray:
316330
"""Assign new coordinates to this object."""
317331

318332
update_kwargs = self.coords.copy()
319333

320-
update_kwargs.update(coords_kwargs)
334+
for key, val in coords_kwargs.items():
335+
update_kwargs[key] = val
336+
321337
if coords:
322-
update_kwargs.update(coords)
338+
for key, val in coords.items():
339+
update_kwargs[key] = val
323340

324341
update_kwargs = {key: np.array(value).tolist() for key, value in update_kwargs.items()}
325342
return self.updated_copy(coords=update_kwargs)

tidy3d/plugins/adjoint/components/data/sim_data.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,7 @@ def from_sim_data(
125125
mnt_data_type = JAX_MONITOR_DATA_MAP[mnt_data_type_str]
126126
jax_mnt_data = mnt_data_type.from_monitor_data(mnt_data)
127127
output_data_list.append(jax_mnt_data)
128-
data_dict["output_data"] = output_data_list
129-
128+
data_dict["output_data"] = output_data_list
130129
self_dict.update(data_dict)
131130
self_dict.update(dict(task_id=task_id))
132131

0 commit comments

Comments
 (0)