33import inspect
44from collections .abc import Sequence
55from dataclasses import dataclass
6- from typing import TYPE_CHECKING , Callable , NamedTuple , Optional , Union
6+ from typing import TYPE_CHECKING , Any , Callable , NamedTuple , Optional , Union
77
88import numpy as np
99
1010from tidy3d .exceptions import AdjointError
1111
1212if TYPE_CHECKING :
13+ from tidy3d .components .autograd .derivative_utils import DerivativeInfo
14+ from tidy3d .components .autograd .types import PathType
1315 from tidy3d .components .autograd import AutogradFieldMap
1416 from tidy3d .components .geometry .utils import GeometryType
1517 from tidy3d .components .medium import MediumType
1618 from tidy3d .components .simulation import Simulation
19+ from tidy3d .components .structure import Structure
1720 from tidy3d .components .types import ArrayLike
1821
1922
@@ -45,10 +48,10 @@ def cylinder_vjp(parameters, derivative_info):
4548 )
4649 """
4750
48- create : Callable
51+ create : Callable [[ ArrayLike ], Structure ]
4952 """Function that creates the structure from static ``parameters``."""
5053
51- compute_derivatives : Callable
54+ compute_derivatives : Callable [[ ArrayLike , DerivativeInfo ], dict [ PathType , Any ]]
5255 """Function that computes numerical gradients for ``("numerical", index, param_i)`` paths.
5356 Signature: ``compute_derivatives(parameters, derivative_info) -> dict[path, gradient]``.
5457 """
@@ -57,11 +60,18 @@ def cylinder_vjp(parameters, derivative_info):
5760 """1D parameter vector consumed by ``create`` and ``compute_derivatives``."""
5861
5962 def __post_init__ (self ) -> None :
63+ self ._validate_callables ()
64+ self ._validate_create_signature ()
65+ self ._validate_compute_derivatives_signature ()
66+ self ._validate_parameters ()
67+
68+ def _validate_callables (self ) -> None :
6069 if not callable (self .create ):
6170 raise AdjointError ("NumericalStructureConfig.create must be callable." )
6271 if not callable (self .compute_derivatives ):
6372 raise AdjointError ("NumericalStructureConfig.compute_derivatives must be callable." )
6473
74+ def _validate_create_signature (self ) -> None :
6575 create_sig = inspect .signature (self .create )
6676 create_arg_names = list (create_sig .parameters .keys ())
6777 if len (create_arg_names ) != 1 :
@@ -71,6 +81,7 @@ def __post_init__(self) -> None:
7181 f"accepts { len (create_arg_names )} arguments."
7282 )
7383
84+ def _validate_compute_derivatives_signature (self ) -> None :
7485 vjp_sig = inspect .signature (self .compute_derivatives )
7586 vjp_arg_names = list (vjp_sig .parameters .keys ())
7687 if len (vjp_arg_names ) != 2 :
@@ -87,6 +98,7 @@ def __post_init__(self) -> None:
8798 f"{ vjp_arg_names [1 ]} but it should be derivative_info."
8899 )
89100
101+ def _validate_parameters (self ) -> None :
90102 try :
91103 array_params = np .asarray (self .parameters )
92104 except Exception as exc :
@@ -126,7 +138,7 @@ def polyslab_vjp(polyslab, derivative_info):
126138 Can be an index or a geometry/medium type (expanded to matching indices).
127139 """
128140
129- compute_derivatives : Callable
141+ compute_derivatives : Callable [[ GeometryType | MediumType , DerivativeInfo ], dict [ PathType , Any ]]
130142 """Function for computing the targeted vjp value. The function should accept the geometry or medium in the
131143 structure depending on if this is a geometry or medium path (see path_key) as the first argument. The second
132144 argument should accept a DerivativeInfo object that contains important for computing the gradient. The function
@@ -141,15 +153,23 @@ def polyslab_vjp(polyslab, derivative_info):
141153 """
142154
143155 def __post_init__ (self ) -> None :
156+ self ._validate_callable ()
157+ self ._validate_compute_derivatives_signature ()
158+
159+ def _validate_callable (self ) -> None :
144160 if not callable (self .compute_derivatives ):
145161 raise AdjointError ("CustomVJPConfig.compute_derivatives must be callable." )
146162
163+ def _validate_compute_derivatives_signature (self ) -> None :
147164 vjp_sig = inspect .signature (self .compute_derivatives )
148165 vjp_arg_names = list (vjp_sig .parameters .keys ())
149166 if len (vjp_arg_names ) != 2 :
150167 raise AdjointError (
151168 "CustomVJPConfig compute_derivatives function should accept two arguments "
152- f"and it currently accepts { len (vjp_arg_names )} arguments."
169+ "(target, derivative_info), and it currently accepts "
170+ f"{ len (vjp_arg_names )} arguments. The target is the geometry or medium "
171+ "instance selected by path_key, and derivative_info contains the field "
172+ "data and path metadata needed to compute the VJP."
153173 )
154174 if vjp_arg_names [1 ] != "derivative_info" :
155175 raise AdjointError (
0 commit comments