2424
2525
2626class DesignRegion :
27- def __init__ (self , design_parameters , volume = None , size = None , center = mp .Vector3 ()):
27+ def __init__ (
28+ self ,
29+ design_parameters : Iterable [onp .ndarray ],
30+ volume : mp .Volume = None ,
31+ size : mp .Vector3 = None ,
32+ center : mp .Vector3 = mp .Vector3 (),
33+ ):
2834 self .volume = volume or mp .Volume (center = center , size = size )
2935 self .size = self .volume .size
3036 self .center = self .volume .center
3137 self .design_parameters = design_parameters
3238 self .num_design_params = design_parameters .num_params
3339
34- def update_design_parameters (self , design_parameters ):
40+ def update_design_parameters (self , design_parameters ) -> None :
3541 self .design_parameters .update_weights (design_parameters )
3642
37- def update_beta (self , beta ) :
43+ def update_beta (self , beta : float ) -> None :
3844 self .design_parameters .beta = beta
3945
4046 def get_gradient (
41- self , sim , fields_a , fields_f , frequencies , finite_difference_step
42- ):
47+ self ,
48+ sim : mp .Simulation ,
49+ fields_a : List [mp .DftFields ],
50+ fields_f : List [mp .DftFields ],
51+ frequencies : List [float ],
52+ finite_difference_step : float ,
53+ ) -> onp .ndarray :
4354 num_freqs = onp .array (frequencies ).size
4455 """We have the option to linearly scale the gradients up front
4556 using the scalegrad parameter (leftover from MPB API). Not
@@ -67,11 +78,11 @@ def get_gradient(
6778 return onp .squeeze (grad ).T
6879
6980
70- def _check_if_cylindrical (sim ) :
81+ def _check_if_cylindrical (sim : mp . Simulation ) -> bool :
7182 return sim .is_cylindrical or (sim .dimensions == mp .CYLINDRICAL )
7283
7384
74- def _compute_components (sim ) :
85+ def _compute_components (sim : mp . Simulation ) -> List [ int ] :
7586 return (
7687 _ADJOINT_FIELD_COMPONENTS_CYL
7788 if _check_if_cylindrical (sim )
@@ -88,8 +99,8 @@ def calculate_vjps(
8899 simulation : mp .Simulation ,
89100 design_regions : List [DesignRegion ],
90101 frequencies : List [float ],
91- fwd_fields : List [List [onp . ndarray ]],
92- adj_fields : List [List [onp . ndarray ]],
102+ fwd_fields : List [List [mp . DftFields ]],
103+ adj_fields : List [List [mp . DftFields ]],
93104 design_variable_shapes : List [Tuple [int , ...]],
94105 sum_freq_partials : bool = True ,
95106 finite_difference_step : float = FD_DEFAULT ,
@@ -132,7 +143,7 @@ def install_design_region_monitors(
132143 design_regions : List [DesignRegion ],
133144 frequencies : List [float ],
134145 decimation_factor : int = 0 ,
135- ) -> List [mp .DftFields ]:
146+ ) -> List [List [ mp .DftFields ] ]:
136147 """Installs DFT field monitors at the design regions of the simulation."""
137148 return [
138149 [
@@ -168,41 +179,6 @@ def gather_monitor_values(monitors: List[ObjectiveQuantity]) -> onp.ndarray:
168179 return monitor_values
169180
170181
171- def gather_design_region_fields (
172- simulation : mp .Simulation ,
173- design_region_monitors : List [mp .DftFields ],
174- frequencies : List [float ],
175- ) -> List [List [onp .ndarray ]]:
176- """Collects the design region DFT fields from the simulation.
177-
178- Args:
179- simulation: the simulation object.
180- design_region_monitors: the installed design region monitors.
181- frequencies: the frequencies to monitor.
182-
183- Returns:
184- A list of lists. Each entry (list) in the overall list corresponds one-to-
185- one with a declared design region. For each such contained list, the
186- entries correspond to the field components that are monitored. The entries
187- are ndarrays of rank 4 with dimensions (freq, x, y, (z-or-pad)).
188-
189- The design region fields are sampled on the *Yee grid*. This makes them
190- fairly awkward to inspect directly. Their primary use case is supporting
191- gradient calculations.
192- """
193- design_region_fields = []
194- for monitor in design_region_monitors :
195- fields_by_component = []
196- for component in _compute_components (simulation ):
197- fields_by_freq = []
198- for freq_idx , _ in enumerate (frequencies ):
199- fields = simulation .get_dft_array (monitor , component , freq_idx )
200- fields_by_freq .append (_make_at_least_nd (fields ))
201- fields_by_component .append (onp .stack (fields_by_freq ))
202- design_region_fields .append (fields_by_component )
203- return design_region_fields
204-
205-
206182def validate_and_update_design (
207183 design_regions : List [DesignRegion ], design_variables : Iterable [onp .ndarray ]
208184) -> None :
0 commit comments