Skip to content

Commit ea43f83

Browse files
authored
add/update type hints for JAX wrapper of adjoint solver (#2190)
* add/update type hints for JAX wrapper of adjoint solver * fix return type of install_design_region_monitors
1 parent 2aa9164 commit ea43f83

File tree

3 files changed

+43
-59
lines changed

3 files changed

+43
-59
lines changed

python/adjoint/utils.py

Lines changed: 21 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,33 @@
2424

2525

2626
class DesignRegion:
27-
def __init__(self, design_parameters, volume=None, size=None, center=mp.Vector3()):
27+
def __init__(
28+
self,
29+
design_parameters: Iterable[onp.ndarray],
30+
volume: mp.Volume = None,
31+
size: mp.Vector3 = None,
32+
center: mp.Vector3 = mp.Vector3(),
33+
):
2834
self.volume = volume or mp.Volume(center=center, size=size)
2935
self.size = self.volume.size
3036
self.center = self.volume.center
3137
self.design_parameters = design_parameters
3238
self.num_design_params = design_parameters.num_params
3339

34-
def update_design_parameters(self, design_parameters):
40+
def update_design_parameters(self, design_parameters) -> None:
3541
self.design_parameters.update_weights(design_parameters)
3642

37-
def update_beta(self, beta):
43+
def update_beta(self, beta: float) -> None:
3844
self.design_parameters.beta = beta
3945

4046
def get_gradient(
41-
self, sim, fields_a, fields_f, frequencies, finite_difference_step
42-
):
47+
self,
48+
sim: mp.Simulation,
49+
fields_a: List[mp.DftFields],
50+
fields_f: List[mp.DftFields],
51+
frequencies: List[float],
52+
finite_difference_step: float,
53+
) -> onp.ndarray:
4354
num_freqs = onp.array(frequencies).size
4455
"""We have the option to linearly scale the gradients up front
4556
using the scalegrad parameter (leftover from MPB API). Not
@@ -67,11 +78,11 @@ def get_gradient(
6778
return onp.squeeze(grad).T
6879

6980

70-
def _check_if_cylindrical(sim):
81+
def _check_if_cylindrical(sim: mp.Simulation) -> bool:
7182
return sim.is_cylindrical or (sim.dimensions == mp.CYLINDRICAL)
7283

7384

74-
def _compute_components(sim):
85+
def _compute_components(sim: mp.Simulation) -> List[int]:
7586
return (
7687
_ADJOINT_FIELD_COMPONENTS_CYL
7788
if _check_if_cylindrical(sim)
@@ -88,8 +99,8 @@ def calculate_vjps(
8899
simulation: mp.Simulation,
89100
design_regions: List[DesignRegion],
90101
frequencies: List[float],
91-
fwd_fields: List[List[onp.ndarray]],
92-
adj_fields: List[List[onp.ndarray]],
102+
fwd_fields: List[List[mp.DftFields]],
103+
adj_fields: List[List[mp.DftFields]],
93104
design_variable_shapes: List[Tuple[int, ...]],
94105
sum_freq_partials: bool = True,
95106
finite_difference_step: float = FD_DEFAULT,
@@ -132,7 +143,7 @@ def install_design_region_monitors(
132143
design_regions: List[DesignRegion],
133144
frequencies: List[float],
134145
decimation_factor: int = 0,
135-
) -> List[mp.DftFields]:
146+
) -> List[List[mp.DftFields]]:
136147
"""Installs DFT field monitors at the design regions of the simulation."""
137148
return [
138149
[
@@ -168,41 +179,6 @@ def gather_monitor_values(monitors: List[ObjectiveQuantity]) -> onp.ndarray:
168179
return monitor_values
169180

170181

171-
def gather_design_region_fields(
172-
simulation: mp.Simulation,
173-
design_region_monitors: List[mp.DftFields],
174-
frequencies: List[float],
175-
) -> List[List[onp.ndarray]]:
176-
"""Collects the design region DFT fields from the simulation.
177-
178-
Args:
179-
simulation: the simulation object.
180-
design_region_monitors: the installed design region monitors.
181-
frequencies: the frequencies to monitor.
182-
183-
Returns:
184-
A list of lists. Each entry (list) in the overall list corresponds one-to-
185-
one with a declared design region. For each such contained list, the
186-
entries correspond to the field components that are monitored. The entries
187-
are ndarrays of rank 4 with dimensions (freq, x, y, (z-or-pad)).
188-
189-
The design region fields are sampled on the *Yee grid*. This makes them
190-
fairly awkward to inspect directly. Their primary use case is supporting
191-
gradient calculations.
192-
"""
193-
design_region_fields = []
194-
for monitor in design_region_monitors:
195-
fields_by_component = []
196-
for component in _compute_components(simulation):
197-
fields_by_freq = []
198-
for freq_idx, _ in enumerate(frequencies):
199-
fields = simulation.get_dft_array(monitor, component, freq_idx)
200-
fields_by_freq.append(_make_at_least_nd(fields))
201-
fields_by_component.append(onp.stack(fields_by_freq))
202-
design_region_fields.append(fields_by_component)
203-
return design_region_fields
204-
205-
206182
def validate_and_update_design(
207183
design_regions: List[DesignRegion], design_variables: Iterable[onp.ndarray]
208184
) -> None:

python/adjoint/wrapper.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def loss(x):
4646
value, grad = jax.value_and_grad(loss)(x)
4747
```
4848
"""
49-
from typing import Callable, List, Tuple
49+
from typing import Callable, Iterable, List, Tuple
5050

5151
import jax
5252
import jax.numpy as jnp
@@ -137,7 +137,9 @@ def __call__(self, designs: List[jnp.ndarray]) -> jnp.ndarray:
137137
"""
138138
return self._simulate_fn(designs)
139139

140-
def _run_fwd_simulation(self, design_variables):
140+
def _run_fwd_simulation(
141+
self, design_variables: Iterable[onp.ndarray]
142+
) -> (jnp.ndarray, List[List[mp.DftFields]]):
141143
"""Runs forward simulation, returning monitor values and design region fields."""
142144
utils.validate_and_update_design(self.design_regions, design_variables)
143145
self.simulation.reset_meep()
@@ -161,7 +163,9 @@ def _run_fwd_simulation(self, design_variables):
161163
monitor_values = utils.gather_monitor_values(self.monitors)
162164
return (jnp.asarray(monitor_values), fwd_design_region_monitors)
163165

164-
def _run_adjoint_simulation(self, monitor_values_grad):
166+
def _run_adjoint_simulation(
167+
self, monitor_values_grad: onp.ndarray
168+
) -> List[List[mp.DftFields]]:
165169
"""Runs adjoint simulation, returning design region fields."""
166170
if not self.design_regions:
167171
raise RuntimeError(
@@ -195,11 +199,11 @@ def _run_adjoint_simulation(self, monitor_values_grad):
195199

196200
def _calculate_vjps(
197201
self,
198-
fwd_fields,
199-
adj_fields,
200-
design_variable_shapes,
201-
sum_freq_partials=True,
202-
):
202+
fwd_fields: List[List[mp.DftFields]],
203+
adj_fields: List[List[mp.DftFields]],
204+
design_variable_shapes: List[Tuple[int, ...]],
205+
sum_freq_partials: bool = True,
206+
) -> List[onp.ndarray]:
203207
"""Calculates the VJP for a given set of forward and adjoint fields."""
204208
return utils.calculate_vjps(
205209
self.simulation,

python/tests/test_adjoint_jax.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,20 @@
99

1010
import meep as mp
1111

12-
# The calculation of finite difference gradients requires that JAX be operated with double precision
12+
# The calculation of finite-difference gradients
13+
# requires that JAX be operated with double precision
1314
jax.config.update("jax_enable_x64", True)
1415

15-
# The step size for the finite difference gradient calculation
16+
# The step size for the finite-difference
17+
# gradient calculation
1618
_FD_STEP = 1e-4
1719

18-
# The tolerance for the adjoint and finite difference gradient comparison
20+
# The tolerance for the adjoint and finite-difference
21+
# gradient comparison
1922
_TOL = 0.1 if mp.is_single_precision() else 0.025
2023

21-
# We expect 3 design region monitor pointers (one for each field component)
24+
# We expect 3 design region monitor pointers
25+
# (one for each field component)
2226
_NUM_DES_REG_MON = 3
2327

2428
mp.verbosity(0)
@@ -257,8 +261,8 @@ def loss_fn(x, excite_port_idx=0):
257261
frequencies,
258262
)
259263
monitor_values = wrapped_meep([x])
260-
s1p, s1m, s2m, s2p = monitor_values
261-
t = s2m / s1p if excite_port_idx == 0 else s1m / s2p
264+
s1p, s1m, s2p, s2m = monitor_values
265+
t = s2p / s1p if excite_port_idx == 0 else s1m / s2m
262266
return jnp.mean(jnp.square(jnp.abs(t)))
263267

264268
value, adjoint_grad = jax.value_and_grad(loss_fn)(

0 commit comments

Comments
 (0)