Skip to content

Commit 6592c97

Browse files
tylerflexmomchil-flex
authored andcommitted
adjoint refactor with separate jax fields
1 parent 996ec08 commit 6592c97

File tree

11 files changed

+311
-380
lines changed

11 files changed

+311
-380
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222

2323
### Changed
2424
- All solver output is now compressed. However, it is automatically unpacked to the same `simulation_data.hdf5` by default when loading simulation data from the server.
25+
- Internal refactor of `adjoint` plugin to separate `jax`-traced fields from regular `tidy3d` fields.
2526

2627
### Fixed
2728

tests/test_plugins/test_adjoint.py

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests adjoint plugin."""
22

33
from typing import Tuple, Dict, List
4+
import builtins
45

56
import pytest
67
import pydantic.v1 as pydantic
@@ -471,6 +472,8 @@ def get_flux(x):
471472
def test_adjoint_pipeline(local, use_emulated_run, tmp_path):
472473
"""Test computing gradient using jax."""
473474

475+
td.config.logging_level = "ERROR"
476+
474477
run_fn = run_local if local else run
475478

476479
sim = make_sim(permittivity=EPS, size=SIZE, vertices=VERTICES, base_eps_val=BASE_EPS_VAL)
@@ -487,6 +490,22 @@ def f(permittivity, size, vertices, base_eps_val):
487490
grad_f = grad(f, argnums=(0, 1, 2, 3))
488491
df_deps, df_dsize, df_dvertices, d_eps_base = grad_f(EPS, SIZE, VERTICES, BASE_EPS_VAL)
489492

493+
# fail if all gradients close to zero
494+
assert not all(
495+
np.all(np.isclose(x, 0)) for x in (df_deps, df_dsize, df_dvertices, d_eps_base)
496+
), "No gradients registered"
497+
498+
# fail if any gradients close to zero
499+
assert not any(
500+
np.any(np.isclose(x, 0)) for x in (df_deps, df_dsize, df_dvertices, d_eps_base)
501+
), "Some of the gradients are zero unexpectedly."
502+
503+
# fail if some gradients dont match the pre/2.6 grads (before refactor).
504+
if local:
505+
assert np.isclose(df_deps, 1278130200000000.0), "local grad doesn't match previous value."
506+
else:
507+
assert np.isclose(df_deps, 0.031742122), "non-local grad doesn't match previous value."
508+
490509
print("gradient: ", df_deps, df_dsize, df_dvertices, d_eps_base)
491510

492511

@@ -1328,11 +1347,18 @@ def f(scale=1.0):
13281347

13291348
def test_validate_vertices():
13301349
"""Test the maximum number of vertices."""
1331-
vertices = np.random.rand(MAX_NUM_VERTICES, 2)
1332-
_ = JaxPolySlab(vertices=vertices, slab_bounds=(-1, 1))
1333-
vertices = np.random.rand(MAX_NUM_VERTICES + 1, 2)
1350+
1351+
def make_vertices(n: int) -> np.ndarray:
1352+
"""Make circular polygon vertices of shape (n, 2)."""
1353+
angles = np.linspace(0, 2 * np.pi, n)
1354+
return np.stack((np.cos(angles), np.sin(angles)), axis=-1)
1355+
1356+
vertices_pass = make_vertices(MAX_NUM_VERTICES)
1357+
_ = JaxPolySlab(vertices=vertices_pass, slab_bounds=(-1, 1))
1358+
13341359
with pytest.raises(pydantic.ValidationError):
1335-
_ = JaxPolySlab(vertices=vertices, slab_bounds=(-1, 1))
1360+
vertices_fail = make_vertices(MAX_NUM_VERTICES + 1)
1361+
_ = JaxPolySlab(vertices=vertices_fail, slab_bounds=(-1, 1))
13361362

13371363

13381364
def _test_custom_medium_3D(use_emulated_run):
@@ -1695,3 +1721,46 @@ def test_nonlinear_warn(log_capture):
16951721
# nonlinear input_structure (warn)
16961722
with AssertLogLevel(log_capture, "WARNING"):
16971723
sim = sim_base.updated_copy(input_structures=[input_struct_nl])
1724+
1725+
1726+
@pytest.fixture
1727+
def hide_jax(monkeypatch, request):
1728+
import_orig = builtins.__import__
1729+
1730+
def mocked_import(name, *args, **kwargs):
1731+
if name in ["jax", "jax.interpreters.ad", "jax.interpreters.ad.JVPTracer"]:
1732+
raise ImportError()
1733+
return import_orig(name, *args, **kwargs)
1734+
1735+
monkeypatch.setattr(builtins, "__import__", mocked_import)
1736+
1737+
1738+
def try_tracer_import() -> None:
1739+
"""Try importing `tidy3d.plugins.adjoint.components.types`."""
1740+
from importlib import reload
1741+
import tidy3d
1742+
1743+
reload(tidy3d.plugins.adjoint.components.types)
1744+
1745+
1746+
@pytest.mark.usefixtures("hide_jax")
1747+
def test_jax_tracer_import_fail(tmp_path, log_capture):
1748+
"""Make sure if import error with JVPTracer, a warning is logged and module still imports."""
1749+
try_tracer_import()
1750+
assert_log_level(log_capture, "WARNING")
1751+
1752+
1753+
def test_jax_tracer_import_pass(tmp_path, log_capture):
1754+
"""Make sure if no import error with JVPTracer, nothing is logged and module imports."""
1755+
try_tracer_import()
1756+
assert_log_level(log_capture, None)
1757+
1758+
1759+
def test_inf_IO(tmp_path):
1760+
"""test that components can save and load "Infinity" properly in jax fields."""
1761+
fname = str(tmp_path / "box.json")
1762+
1763+
box = JaxBox(size=(td.inf, td.inf, td.inf), center=(0, 0, 0))
1764+
box.to_file(fname)
1765+
box2 = JaxBox.from_file(fname)
1766+
assert box == box2

tidy3d/components/geometry/polyslab.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,11 @@ def _normal_2dmaterial(self) -> Axis:
427427
raise ValidationError("'Medium2D' requires the 'PolySlab' bounds to be equal.")
428428
return self.axis
429429

430+
@cached_property
431+
def is_ccw(self) -> bool:
432+
"""Is this ``PolySlab`` CCW-oriented?"""
433+
return PolySlab._area(self.vertices) > 0
434+
430435
def inside(
431436
self, x: np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float]
432437
) -> np.ndarray[bool]:

tidy3d/plugins/adjoint/components/base.py

Lines changed: 139 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import json
66

77
import numpy as np
8+
import jax
9+
import pydantic.v1 as pd
810

911
from jax.tree_util import tree_flatten as jax_tree_flatten
1012
from jax.tree_util import tree_unflatten as jax_tree_unflatten
@@ -16,67 +18,78 @@
1618
class JaxObject(Tidy3dBaseModel):
1719
"""Abstract class that makes a :class:`.Tidy3dBaseModel` jax-compatible through inheritance."""
1820

19-
"""Shortcut to get names of all fields that have jax components."""
21+
_tidy3d_class = Tidy3dBaseModel
22+
23+
"""Shortcut to get names of fields with certain properties."""
2024

2125
@classmethod
22-
def get_jax_field_names(cls) -> List[str]:
23-
"""Returns list of field names that have a ``jax_field_type``."""
24-
adjoint_fields = []
26+
def _get_field_names(cls, field_key: str) -> List[str]:
27+
"""Get all fields where ``field_key`` defined in the ``pydantic.Field``."""
28+
fields = []
2529
for field_name, model_field in cls.__fields__.items():
26-
jax_field_type = model_field.field_info.extra.get("jax_field")
27-
if jax_field_type:
28-
adjoint_fields.append(field_name)
29-
return adjoint_fields
30+
field_value = model_field.field_info.extra.get(field_key)
31+
if field_value:
32+
fields.append(field_name)
33+
return fields
34+
35+
@classmethod
36+
def get_jax_field_names(cls) -> List[str]:
37+
"""Returns list of field names where ``jax_field=True``."""
38+
return cls._get_field_names("jax_field")
39+
40+
@classmethod
41+
def get_jax_leaf_names(cls) -> List[str]:
42+
"""Returns list of field names where ``stores_jax_for`` defined."""
43+
return cls._get_field_names("stores_jax_for")
44+
45+
@classmethod
46+
def get_jax_field_names_all(cls) -> List[str]:
47+
"""Returns list of field names where ``jax_field=True`` or ``stores_jax_for`` defined."""
48+
jax_field_names = cls.get_jax_field_names()
49+
jax_leaf_names = cls.get_jax_leaf_names()
50+
return list(set(jax_field_names + jax_leaf_names))
51+
52+
@property
53+
def jax_fields(self) -> dict:
54+
"""Get dictionary of ``jax`` fields."""
55+
56+
# TODO: don't use getattr, define this dictionary better
57+
jax_field_names = self.get_jax_field_names()
58+
return {key: getattr(self, key) for key in jax_field_names}
3059

3160
"""Methods needed for jax to register arbitrary classes."""
3261

3362
def tree_flatten(self) -> Tuple[list, dict]:
34-
"""How to flatten a :class:`.JaxObject` instance into a pytree."""
63+
"""How to flatten a :class:`.JaxObject` instance into a ``pytree``."""
3564
children = []
3665
aux_data = self.dict()
37-
for field_name in self.get_jax_field_names():
66+
67+
for field_name in self.get_jax_field_names_all():
3868
field = getattr(self, field_name)
3969
sub_children, sub_aux_data = jax_tree_flatten(field)
4070
children.append(sub_children)
4171
aux_data[field_name] = sub_aux_data
4272

43-
def fix_polyslab(geo_dict: dict) -> None:
44-
"""Recursively Fix a dictionary possibly containing a polyslab geometry."""
45-
if geo_dict["type"] == "PolySlab":
46-
vertices = geo_dict["vertices"]
47-
geo_dict["vertices"] = vertices.tolist()
48-
elif geo_dict["type"] == "GeometryGroup":
49-
for sub_geo_dict in geo_dict["geometries"]:
50-
fix_polyslab(sub_geo_dict)
51-
elif geo_dict["type"] == "ClipOperation":
52-
fix_polyslab(geo_dict["geometry_a"])
53-
fix_polyslab(geo_dict["geometry_b"])
54-
55-
def fix_monitor(mnt_dict: dict) -> None:
56-
"""Fix a frequency containing monitor."""
57-
if "freqs" in mnt_dict:
58-
freqs = mnt_dict["freqs"]
59-
if isinstance(freqs, np.ndarray):
60-
mnt_dict["freqs"] = freqs.tolist()
61-
62-
# fixes bug with jax handling 2D numpy array in polyslab vertices
63-
if aux_data.get("type", "") == "JaxSimulation":
64-
structures = aux_data["structures"]
65-
for _i, structure in enumerate(structures):
66-
geometry = structure["geometry"]
67-
fix_polyslab(geometry)
68-
for monitor in aux_data["monitors"]:
69-
fix_monitor(monitor)
70-
for monitor in aux_data["output_monitors"]:
71-
fix_monitor(monitor)
73+
def fix_numpy(value: Any) -> Any:
74+
"""Recursively convert any ``numpy`` array in the value to nested list."""
75+
if isinstance(value, (tuple, list)):
76+
return [fix_numpy(val) for val in value]
77+
if isinstance(value, np.ndarray):
78+
return value.tolist()
79+
if isinstance(value, dict):
80+
return {key: fix_numpy(val) for key, val in value.items()}
81+
else:
82+
return value
83+
84+
aux_data = fix_numpy(aux_data)
7285

7386
return children, aux_data
7487

7588
@classmethod
7689
def tree_unflatten(cls, aux_data: dict, children: list) -> JaxObject:
77-
"""How to unflatten a pytree into a :class:`.JaxObject` instance."""
90+
"""How to unflatten a ``pytree`` into a :class:`.JaxObject` instance."""
7891
self_dict = aux_data.copy()
79-
for field_name, sub_children in zip(cls.get_jax_field_names(), children):
92+
for field_name, sub_children in zip(cls.get_jax_field_names_all(), children):
8093
sub_aux_data = aux_data[field_name]
8194
field = jax_tree_unflatten(sub_aux_data, sub_children)
8295
self_dict[field_name] = field
@@ -85,38 +98,110 @@ def tree_unflatten(cls, aux_data: dict, children: list) -> JaxObject:
8598

8699
"""Type conversion helpers."""
87100

101+
def to_tidy3d(self: JaxObject) -> Tidy3dBaseModel:
102+
"""Convert :class:`.JaxObject` instance to :class:`.Tidy3dBaseModel` instance."""
103+
104+
self_dict = self.dict(exclude=self.exclude_fields_leafs_only)
105+
106+
for key in self.get_jax_field_names():
107+
sub_field = self.jax_fields[key]
108+
109+
# TODO: simplify this logic
110+
if isinstance(sub_field, (tuple, list)):
111+
self_dict[key] = [x.to_tidy3d() for x in sub_field]
112+
else:
113+
self_dict[key] = sub_field.to_tidy3d()
114+
# end TODO
115+
116+
return self._tidy3d_class.parse_obj(self_dict)
117+
88118
@classmethod
89119
def from_tidy3d(cls, tidy3d_obj: Tidy3dBaseModel) -> JaxObject:
90120
"""Convert :class:`.Tidy3dBaseModel` instance to :class:`.JaxObject`."""
91121
obj_dict = tidy3d_obj.dict(exclude={"type"})
122+
123+
for key in cls.get_jax_field_names():
124+
sub_field_type = cls.__fields__[key].type_
125+
tidy3d_sub_field = getattr(tidy3d_obj, key)
126+
127+
# TODO: simplify this logic
128+
if isinstance(tidy3d_sub_field, (tuple, list)):
129+
obj_dict[key] = [sub_field_type.from_tidy3d(x) for x in tidy3d_sub_field]
130+
else:
131+
obj_dict[key] = sub_field_type.from_tidy3d(tidy3d_sub_field)
132+
# end TODO
133+
92134
return cls.parse_obj(obj_dict)
93135

136+
@property
137+
def exclude_fields_leafs_only(self) -> set:
138+
"""Fields to exclude from ``self.dict()``, ``"type"`` and all ``jax`` leafs."""
139+
return set(["type"] + self.get_jax_leaf_names())
140+
141+
"""Accounting with jax and regular fields."""
142+
143+
@pd.root_validator(pre=True)
144+
def handle_jax_kwargs(cls, values: dict) -> dict:
145+
"""Pass jax inputs to the jax fields and pass untraced values to the regular fields."""
146+
147+
# for all jax-traced fields
148+
for jax_name in cls.get_jax_leaf_names():
149+
# if a value was passed to the object for the regular field
150+
orig_name = cls.__fields__[jax_name].field_info.extra.get("stores_jax_for")
151+
val = values.get(orig_name)
152+
if val is not None:
153+
154+
# try adding the sanitized (no trace) version to the regular field
155+
try:
156+
values[orig_name] = jax.lax.stop_gradient(val)
157+
158+
# if it doesnt work, just pass the raw value (necessary to handle inf strings)
159+
except TypeError:
160+
values[orig_name] = val
161+
162+
# if the jax name was not specified directly, use the original traced value
163+
if jax_name not in values:
164+
values[jax_name] = val
165+
166+
return values
167+
168+
@pd.root_validator(pre=True)
169+
def handle_array_jax_leafs(cls, values: dict) -> dict:
170+
"""Convert jax_leafs that are passed as numpy arrays."""
171+
for jax_name in cls.get_jax_leaf_names():
172+
val = values.get(jax_name)
173+
if isinstance(val, np.ndarray):
174+
values[jax_name] = val.tolist()
175+
return values
176+
94177
""" IO """
95178

179+
# TODO: replace with JaxObject json encoder
180+
96181
def _json(self, *args, **kwargs) -> str:
97182
"""Overwritten method to get the json string to store in the files."""
98183

99184
json_string_og = super()._json(*args, **kwargs)
100185
json_dict = json.loads(json_string_og)
101186

102-
def strip_data_array(sub_dict: dict) -> None:
103-
"""Strip any elements of the dictionary with type "JaxDataArray", replace with tag."""
187+
def strip_data_array(val: Any) -> Any:
188+
"""Recursively strip any elements with type "JaxDataArray", replace with tag."""
104189

105-
for key, val in sub_dict.items():
190+
if isinstance(val, dict):
191+
if "type" in val and val["type"] == "JaxDataArray":
192+
return JAX_DATA_ARRAY_TAG
193+
return {k: strip_data_array(v) for k, v in val.items()}
106194

107-
if isinstance(val, dict):
108-
if "type" in val and val["type"] == "JaxDataArray":
109-
sub_dict[key] = JAX_DATA_ARRAY_TAG
110-
else:
111-
strip_data_array(val)
112-
elif isinstance(val, (list, tuple)):
113-
val_dict = dict(zip(range(len(val)), val))
114-
strip_data_array(val_dict)
115-
sub_dict[key] = list(val_dict.values())
195+
elif isinstance(val, (tuple, list)):
196+
return [strip_data_array(v) for v in val]
116197

117-
strip_data_array(json_dict)
198+
return val
199+
200+
json_dict = strip_data_array(json_dict)
118201
return json.dumps(json_dict)
119202

203+
# TODO: replace with implementing these in DataArray
204+
120205
def to_hdf5(self, fname: str, custom_encoders: List[Callable] = None) -> None:
121206
"""Exports :class:`JaxObject` instance to .hdf5 file.
122207

0 commit comments

Comments
 (0)