Skip to content

Commit 2c8988d

Browse files
Gregory RobertsGregory Roberts
authored andcommitted
code review and bugbot
1 parent 977496e commit 2c8988d

File tree

3 files changed

+47
-34
lines changed

3 files changed

+47
-34
lines changed

tests/test_components/autograd/test_autograd.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1386,7 +1386,10 @@ def polyslab_custom_vjp_bad_arg_name(polyslab, d_info):
13861386

13871387
with pytest.raises(
13881388
td.exceptions.AdjointError,
1389-
match="CustomVJPConfig compute_derivatives function should accept two arguments and it currently accepts 3 arguments.",
1389+
match=(
1390+
"CustomVJPConfig compute_derivatives function should accept two arguments "
1391+
r"\(target, derivative_info\), and it currently accepts 3 arguments\."
1392+
),
13901393
):
13911394
CustomVJPConfig(
13921395
structure=td.PolySlab,

tidy3d/web/api/autograd/autograd.py

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -105,23 +105,6 @@ def insert_numerical_structures_static(
105105
return updated_simulation
106106

107107

108-
def _normalize_simulations_input(
109-
simulations: Union[dict[str, td.Simulation], tuple[td.Simulation], list[td.Simulation]],
110-
) -> dict[str, td.Simulation]:
111-
"""Normalize simulations to a dict keyed by task name."""
112-
113-
if isinstance(simulations, dict):
114-
return simulations
115-
116-
normalized: dict[str, td.Simulation] = {}
117-
118-
for idx, sim in enumerate(simulations):
119-
task_name = Tidy3dStub(simulation=sim).get_default_task_name() + f"_{idx + 1}"
120-
normalized[task_name] = sim
121-
122-
return normalized
123-
124-
125108
def has_traced_numerical_structures(
126109
numerical_structures: Union[
127110
tuple[NumericalStructureConfig, ...],
@@ -137,6 +120,10 @@ def has_traced_numerical_structures(
137120
else numerical_structures
138121
)
139122
for cfg in iterable_structures:
123+
if not isinstance(cfg, NumericalStructureConfig):
124+
raise AdjointError(
125+
"Entries in 'numerical_structures' must be NumericalStructureConfig instances."
126+
)
140127
parameters = cfg.parameters
141128
if hasbox(parameters):
142129
return True
@@ -746,6 +733,13 @@ def _expand_spec(
746733
validate_numerical_structure_parameters(
747734
numerical_structures=numerical_structures_configs
748735
)
736+
if any(numerical_structures.values()) and not all(
737+
isinstance(sim, td.Simulation) for sim in sim_dict.values()
738+
):
739+
raise AdjointError(
740+
"numerical_structures is only supported for 'Simulation' workflows in "
741+
"run_async_custom."
742+
)
749743

750744
custom_vjp = _expand_spec(
751745
fn_arg=custom_vjp,
@@ -755,18 +749,14 @@ def _expand_spec(
755749
arg_name="custom_vjp",
756750
)
757751

758-
simulations = sim_dict
759-
760752
path_dir = Path(path_dir)
761753

762-
simulations_norm = _normalize_simulations_input(simulations)
763-
764754
traced_numerical_structures = bool(numerical_structures) and any(
765755
has_traced_numerical_structures(numerical_structure)
766756
for _, numerical_structure in numerical_structures.items()
767757
)
768758
should_use_autograd_async = (
769-
is_valid_for_autograd_async(simulations_norm) or traced_numerical_structures
759+
is_valid_for_autograd_async(sim_dict) or traced_numerical_structures
770760
)
771761

772762
if should_use_autograd_async:
@@ -782,13 +772,13 @@ def _expand_spec(
782772
expanded_custom_vjp_dict = {}
783773
for sim_key, custom_vjp_entry in custom_vjp.items():
784774
expanded_custom_vjp_dict[sim_key] = expand_custom_vjp(
785-
custom_vjp_entry, simulations_norm[sim_key]
775+
custom_vjp_entry, sim_dict[sim_key]
786776
)
787777
else:
788778
expanded_custom_vjp_dict = None
789779

790780
return _run_async(
791-
simulations=simulations_norm,
781+
simulations=sim_dict,
792782
folder_name=folder_name,
793783
path_dir=path_dir,
794784
callback_url=callback_url,
@@ -811,16 +801,16 @@ def _expand_spec(
811801
simulations_static = {
812802
name: (
813803
insert_numerical_structures_static(
814-
simulation=simulations_norm[name],
804+
simulation=sim_dict[name],
815805
numerical_structures=numerical_structures[name],
816806
)
817807
if numerical_structures[name]
818-
else simulations_norm[name]
808+
else sim_dict[name]
819809
)
820-
for name in simulations_norm
810+
for name in sim_dict
821811
}
822812
else:
823-
simulations_static = simulations_norm
813+
simulations_static = sim_dict
824814

825815
return asynchronous_webapi.run_async(
826816
simulations=simulations_static,

tidy3d/web/api/autograd/types.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,20 @@
33
import inspect
44
from collections.abc import Sequence
55
from dataclasses import dataclass
6-
from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Union
6+
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Union
77

88
import numpy as np
99

1010
from tidy3d.exceptions import AdjointError
1111

1212
if TYPE_CHECKING:
13+
from tidy3d.components.autograd.derivative_utils import DerivativeInfo
14+
from tidy3d.components.autograd.types import PathType
1315
from tidy3d.components.autograd import AutogradFieldMap
1416
from tidy3d.components.geometry.utils import GeometryType
1517
from tidy3d.components.medium import MediumType
1618
from tidy3d.components.simulation import Simulation
19+
from tidy3d.components.structure import Structure
1720
from tidy3d.components.types import ArrayLike
1821

1922

@@ -45,10 +48,10 @@ def cylinder_vjp(parameters, derivative_info):
4548
)
4649
"""
4750

48-
create: Callable
51+
create: Callable[[ArrayLike], Structure]
4952
"""Function that creates the structure from static ``parameters``."""
5053

51-
compute_derivatives: Callable
54+
compute_derivatives: Callable[[ArrayLike, DerivativeInfo], dict[PathType, Any]]
5255
"""Function that computes numerical gradients for ``("numerical", index, param_i)`` paths.
5356
Signature: ``compute_derivatives(parameters, derivative_info) -> dict[path, gradient]``.
5457
"""
@@ -57,11 +60,18 @@ def cylinder_vjp(parameters, derivative_info):
5760
"""1D parameter vector consumed by ``create`` and ``compute_derivatives``."""
5861

5962
def __post_init__(self) -> None:
63+
self._validate_callables()
64+
self._validate_create_signature()
65+
self._validate_compute_derivatives_signature()
66+
self._validate_parameters()
67+
68+
def _validate_callables(self) -> None:
6069
if not callable(self.create):
6170
raise AdjointError("NumericalStructureConfig.create must be callable.")
6271
if not callable(self.compute_derivatives):
6372
raise AdjointError("NumericalStructureConfig.compute_derivatives must be callable.")
6473

74+
def _validate_create_signature(self) -> None:
6575
create_sig = inspect.signature(self.create)
6676
create_arg_names = list(create_sig.parameters.keys())
6777
if len(create_arg_names) != 1:
@@ -71,6 +81,7 @@ def __post_init__(self) -> None:
7181
f"accepts {len(create_arg_names)} arguments."
7282
)
7383

84+
def _validate_compute_derivatives_signature(self) -> None:
7485
vjp_sig = inspect.signature(self.compute_derivatives)
7586
vjp_arg_names = list(vjp_sig.parameters.keys())
7687
if len(vjp_arg_names) != 2:
@@ -87,6 +98,7 @@ def __post_init__(self) -> None:
8798
f"{vjp_arg_names[1]} but it should be derivative_info."
8899
)
89100

101+
def _validate_parameters(self) -> None:
90102
try:
91103
array_params = np.asarray(self.parameters)
92104
except Exception as exc:
@@ -126,7 +138,7 @@ def polyslab_vjp(polyslab, derivative_info):
126138
Can be an index or a geometry/medium type (expanded to matching indices).
127139
"""
128140

129-
compute_derivatives: Callable
141+
compute_derivatives: Callable[[GeometryType | MediumType, DerivativeInfo], dict[PathType, Any]]
130142
"""Function for computing the targeted vjp value. The function should accept the geometry or medium in the
131143
structure depending on if this is a geometry or medium path (see path_key) as the first argument. The second
132144
argument should accept a DerivativeInfo object that contains important for computing the gradient. The function
@@ -141,15 +153,23 @@ def polyslab_vjp(polyslab, derivative_info):
141153
"""
142154

143155
def __post_init__(self) -> None:
156+
self._validate_callable()
157+
self._validate_compute_derivatives_signature()
158+
159+
def _validate_callable(self) -> None:
144160
if not callable(self.compute_derivatives):
145161
raise AdjointError("CustomVJPConfig.compute_derivatives must be callable.")
146162

163+
def _validate_compute_derivatives_signature(self) -> None:
147164
vjp_sig = inspect.signature(self.compute_derivatives)
148165
vjp_arg_names = list(vjp_sig.parameters.keys())
149166
if len(vjp_arg_names) != 2:
150167
raise AdjointError(
151168
"CustomVJPConfig compute_derivatives function should accept two arguments "
152-
f"and it currently accepts {len(vjp_arg_names)} arguments."
169+
"(target, derivative_info), and it currently accepts "
170+
f"{len(vjp_arg_names)} arguments. The target is the geometry or medium "
171+
"instance selected by path_key, and derivative_info contains the field "
172+
"data and path metadata needed to compute the VJP."
153173
)
154174
if vjp_arg_names[1] != "derivative_info":
155175
raise AdjointError(

0 commit comments

Comments
 (0)