Skip to content

Commit 1d00c95

Browse files
committed
reorganized code
1 parent 0a9f0f5 commit 1d00c95

File tree

2 files changed

+176
-180
lines changed

2 files changed

+176
-180
lines changed

tidy3d/components/geometry/base.py

Lines changed: 129 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,6 +1640,125 @@ def __invert__(self):
16401640
operation="difference", geometry_a=Box(size=(inf, inf, inf)), geometry_b=self
16411641
)
16421642

1643+
def build_box_face_mesh(
1644+
self,
1645+
center: np.ndarray,
1646+
size: np.ndarray,
1647+
axis_normal: int, # 0,1,2 → x,y,z faces
1648+
min_max_index: int, # 0 = − side, 1 = + side
1649+
rotation_matrix: np.ndarray | None = None,
1650+
) -> tuple[DerivativeSurfaceMesh, np.ndarray, np.ndarray]:
1651+
"""Build a mesh for the face of a box, given the center, size, axis normal, and min/max index.
1652+
The mesh is built in the local coordinate system of the box, and then transformed to the global
1653+
coordinate system using the rotation matrix."""
1654+
1655+
if axis_normal == 0:
1656+
canonical_normal = np.array([1.0, 0.0, 0.0])
1657+
elif axis_normal == 1:
1658+
canonical_normal = np.array([0.0, 1.0, 0.0])
1659+
elif axis_normal == 2:
1660+
canonical_normal = np.array([0.0, 0.0, 1.0])
1661+
else:
1662+
raise ValueError("Invalid axis_normal")
1663+
1664+
if min_max_index == 0:
1665+
canonical_normal *= -1.0
1666+
1667+
if rotation_matrix is None:
1668+
rotation_matrix = np.eye(3)
1669+
1670+
n_local = rotation_matrix @ canonical_normal
1671+
n_local = n_local / np.linalg.norm(n_local)
1672+
1673+
def compute_tangential_vectors(
1674+
normal: np.ndarray, eps: float = 1e-8
1675+
) -> tuple[np.ndarray, np.ndarray]:
1676+
"""Compute any two perpendicular tangential vectors t1, t2, given a normal."""
1677+
if abs(normal[0]) > abs(normal[2]):
1678+
t1 = np.array([-normal[1], normal[0], 0.0])
1679+
else:
1680+
t1 = np.array([0.0, -normal[2], normal[1]])
1681+
t1_norm = np.linalg.norm(t1)
1682+
if t1_norm < eps:
1683+
raise ValueError("Degenerate normal vector.")
1684+
t1 = t1 / t1_norm
1685+
t2 = np.cross(normal, t1)
1686+
t2 /= np.linalg.norm(t2)
1687+
return t1, t2
1688+
1689+
t1_local, t2_local = compute_tangential_vectors(n_local)
1690+
1691+
min_bound = np.array(center) - np.array(size) / 2.0
1692+
max_bound = np.array(center) + np.array(size) / 2.0
1693+
bounds_old = np.column_stack((min_bound, max_bound))
1694+
1695+
corners = np.array(
1696+
[
1697+
[bounds_old[0, i], bounds_old[1, j], bounds_old[2, k]]
1698+
for i in (0, 1)
1699+
for j in (0, 1)
1700+
for k in (0, 1)
1701+
]
1702+
)
1703+
1704+
connectivity = {
1705+
0: { # Faces perpendicular to x-axis
1706+
0: [0, 1, 3, 2],
1707+
1: [4, 5, 7, 6],
1708+
},
1709+
1: { # Faces perpendicular to y-axis
1710+
0: [0, 1, 5, 4],
1711+
1: [2, 3, 7, 6],
1712+
},
1713+
2: { # Faces perpendicular to z-axis
1714+
0: [0, 4, 6, 2],
1715+
1: [1, 5, 7, 3],
1716+
},
1717+
}
1718+
1719+
face_indices = connectivity[axis_normal][min_max_index]
1720+
face_corners = corners[face_indices, :]
1721+
1722+
rotated_corners = (rotation_matrix @ face_corners.T).T
1723+
p1, p2, p3, p4 = rotated_corners
1724+
1725+
num_s = _NUM_PTS_DIM_BOX_FACE
1726+
num_t = _NUM_PTS_DIM_BOX_FACE
1727+
s_vals = np.linspace(0, 1, 2 * num_s + 1)[1::2]
1728+
t_vals = np.linspace(0, 1, 2 * num_t + 1)[1::2]
1729+
S, T = np.meshgrid(s_vals, t_vals, indexing="ij")
1730+
1731+
X = (1 - S) * (1 - T) * p1[0] + S * (1 - T) * p2[0] + S * T * p3[0] + (1 - S) * T * p4[0]
1732+
Y = (1 - S) * (1 - T) * p1[1] + S * (1 - T) * p2[1] + S * T * p3[1] + (1 - S) * T * p4[1]
1733+
Z = (1 - S) * (1 - T) * p1[2] + S * (1 - T) * p2[2] + S * T * p3[2] + (1 - S) * T * p4[2]
1734+
1735+
centers = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()])
1736+
1737+
tri1_area = 0.5 * np.linalg.norm(np.cross((p2 - p1), (p3 - p1)))
1738+
tri2_area = 0.5 * np.linalg.norm(np.cross((p4 - p1), (p3 - p1)))
1739+
face_area = tri1_area + tri2_area
1740+
1741+
num_cells = (num_s) * (num_t)
1742+
if num_cells > 0:
1743+
cell_area = face_area / num_cells
1744+
else:
1745+
cell_area = face_area
1746+
1747+
areas = cell_area * np.ones(centers.shape[0])
1748+
1749+
normals = np.tile(n_local, (centers.shape[0], 1))
1750+
perps1 = np.tile(t1_local, (centers.shape[0], 1))
1751+
perps2 = np.tile(t2_local, (centers.shape[0], 1))
1752+
1753+
surface_mesh = DerivativeSurfaceMesh(
1754+
centers=centers,
1755+
areas=areas,
1756+
normals=normals,
1757+
perps1=perps1,
1758+
perps2=perps2,
1759+
)
1760+
return surface_mesh, n_local
1761+
16431762

16441763
""" Abstract subclasses """
16451764

@@ -2564,119 +2683,17 @@ def derivative_face(
25642683
) -> float:
25652684
"""
25662685
Compute the derivative (VJP) with respect to shifting a face of a rotated box,
2567-
using full integration over that face. This version uses bilinear interpolation
2568-
of the four corners to sample interior points.
2569-
"""
2570-
if axis_normal == 0:
2571-
canonical_normal = np.array([1.0, 0.0, 0.0])
2572-
elif axis_normal == 1:
2573-
canonical_normal = np.array([0.0, 1.0, 0.0])
2574-
elif axis_normal == 2:
2575-
canonical_normal = np.array([0.0, 0.0, 1.0])
2576-
else:
2577-
raise ValueError("Invalid axis_normal")
2578-
2579-
if min_max_index == 0:
2580-
canonical_normal *= -1.0
2581-
2582-
if rotation_matrix is None:
2583-
rotation_matrix = np.eye(3)
2584-
2585-
n_local = rotation_matrix @ canonical_normal
2586-
n_local = n_local / np.linalg.norm(n_local)
2587-
2588-
def compute_tangential_vectors(
2589-
normal: np.ndarray, eps: float = 1e-8
2590-
) -> tuple[np.ndarray, np.ndarray]:
2591-
"""Compute any two perpendicular tangential vectors t1, t2, given a normal."""
2592-
if abs(normal[0]) > abs(normal[2]):
2593-
t1 = np.array([-normal[1], normal[0], 0.0])
2594-
else:
2595-
t1 = np.array([0.0, -normal[2], normal[1]])
2596-
t1_norm = np.linalg.norm(t1)
2597-
if t1_norm < eps:
2598-
raise ValueError("Degenerate normal vector.")
2599-
t1 = t1 / t1_norm
2600-
t2 = np.cross(normal, t1)
2601-
t2 /= np.linalg.norm(t2)
2602-
return t1, t2
2603-
2604-
t1_local, t2_local = compute_tangential_vectors(n_local)
2605-
2606-
min_bound = np.array(self.center) - np.array(self.size) / 2.0
2607-
max_bound = np.array(self.center) + np.array(self.size) / 2.0
2608-
bounds_old = np.column_stack((min_bound, max_bound))
2609-
2610-
corners = np.array(
2611-
[
2612-
[bounds_old[0, i], bounds_old[1, j], bounds_old[2, k]]
2613-
for i in (0, 1)
2614-
for j in (0, 1)
2615-
for k in (0, 1)
2616-
]
2686+
using full integration over that face.
2687+
"""
2688+
mesh, _ = self.build_box_face_mesh(
2689+
center=np.asarray(self.center, float),
2690+
size=np.asarray(self.size, float),
2691+
axis_normal=axis_normal,
2692+
min_max_index=min_max_index,
2693+
rotation_matrix=rotation_matrix,
26172694
)
2618-
2619-
connectivity = {
2620-
0: { # Faces perpendicular to x-axis
2621-
0: [0, 1, 3, 2],
2622-
1: [4, 5, 7, 6],
2623-
},
2624-
1: { # Faces perpendicular to y-axis
2625-
0: [0, 1, 5, 4],
2626-
1: [2, 3, 7, 6],
2627-
},
2628-
2: { # Faces perpendicular to z-axis
2629-
0: [0, 4, 6, 2],
2630-
1: [1, 5, 7, 3],
2631-
},
2632-
}
2633-
2634-
face_indices = connectivity[axis_normal][min_max_index]
2635-
face_corners = corners[face_indices, :]
2636-
2637-
rotated_corners = (rotation_matrix @ face_corners.T).T
2638-
p1, p2, p3, p4 = rotated_corners
2639-
2640-
num_s = _NUM_PTS_DIM_BOX_FACE
2641-
num_t = _NUM_PTS_DIM_BOX_FACE
2642-
s_vals = np.linspace(0, 1, 2 * num_s + 1)[1::2]
2643-
t_vals = np.linspace(0, 1, 2 * num_t + 1)[1::2]
2644-
S, T = np.meshgrid(s_vals, t_vals, indexing="ij")
2645-
2646-
X = (1 - S) * (1 - T) * p1[0] + S * (1 - T) * p2[0] + S * T * p3[0] + (1 - S) * T * p4[0]
2647-
Y = (1 - S) * (1 - T) * p1[1] + S * (1 - T) * p2[1] + S * T * p3[1] + (1 - S) * T * p4[1]
2648-
Z = (1 - S) * (1 - T) * p1[2] + S * (1 - T) * p2[2] + S * T * p3[2] + (1 - S) * T * p4[2]
2649-
2650-
centers = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()])
2651-
2652-
tri1_area = 0.5 * np.linalg.norm(np.cross((p2 - p1), (p3 - p1)))
2653-
tri2_area = 0.5 * np.linalg.norm(np.cross((p4 - p1), (p3 - p1)))
2654-
face_area = tri1_area + tri2_area
2655-
2656-
num_cells = (num_s) * (num_t)
2657-
if num_cells > 0:
2658-
cell_area = face_area / num_cells
2659-
else:
2660-
cell_area = face_area
2661-
2662-
areas = cell_area * np.ones(centers.shape[0])
2663-
2664-
normals = np.tile(n_local, (centers.shape[0], 1))
2665-
perps1 = np.tile(t1_local, (centers.shape[0], 1))
2666-
perps2 = np.tile(t2_local, (centers.shape[0], 1))
2667-
2668-
surface_mesh = DerivativeSurfaceMesh(
2669-
centers=centers,
2670-
areas=areas,
2671-
normals=normals,
2672-
perps1=perps1,
2673-
perps2=perps2,
2674-
)
2675-
2676-
grads = derivative_info.grad_surfaces(surface_mesh=surface_mesh)
2677-
vjp_value = np.real(np.sum(grads).item())
2678-
2679-
return vjp_value
2695+
vjp = derivative_info.grad_surfaces(surface_mesh=mesh)
2696+
return float(np.real(np.sum(vjp)))
26802697

26812698

26822699
"""Compound subclasses"""

tidy3d/web/api/autograd/autograd.py

Lines changed: 47 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,77 +1013,56 @@ def postprocess_adj(
10131013
eps_background = None
10141014

10151015
# manually override simulation medium as the background structure
1016-
if not isinstance(structure.geometry, td.Box):
1017-
# auto permittivity detection
1018-
sim_orig = sim_data_orig.simulation
1019-
plane_eps = eps_fwd.monitor.geometry
1020-
1021-
# get permittivity without this structure
1022-
structs_no_struct = list(sim_orig.structures)
1023-
structs_no_struct.pop(structure_index)
1024-
sim_no_structure = sim_orig.updated_copy(structures=structs_no_struct)
1025-
eps_no_structure = sim_no_structure.epsilon(
1026-
box=plane_eps, coord_key="centers", freq=freq_adj
1027-
)
1028-
1029-
eps_in = np.mean(structure.medium.eps_model(freq_adj))
1030-
eps_out = np.mean(sim_data_orig.simulation.medium.eps_model(freq_adj))
1031-
if structure.background_medium:
1032-
eps_background = structure.background_medium.eps_model(freq_adj)
1033-
else:
1034-
eps_background = None
1035-
1036-
# manually override simulation medium as the background structure
1037-
# auto permittivity detection
1038-
sim_orig = sim_data_orig.simulation
1039-
plane_eps = eps_fwd.monitor.geometry
1040-
1041-
# get permittivity without this structure
1042-
structs_no_struct = list(sim_orig.structures)
1043-
structs_no_struct.pop(structure_index)
1044-
sim_no_structure = sim_orig.updated_copy(structures=structs_no_struct)
1045-
eps_no_structure = sim_no_structure.epsilon(
1046-
box=plane_eps, coord_key="centers", freq=freq_adj
1047-
)
1016+
# auto permittivity detection
1017+
sim_orig = sim_data_orig.simulation
1018+
plane_eps = eps_fwd.monitor.geometry
1019+
1020+
# get permittivity without this structure
1021+
structs_no_struct = list(sim_orig.structures)
1022+
structs_no_struct.pop(structure_index)
1023+
sim_no_structure = sim_orig.updated_copy(structures=structs_no_struct)
1024+
eps_no_structure = sim_no_structure.epsilon(
1025+
box=plane_eps, coord_key="centers", freq=freq_adj
1026+
)
10481027

1049-
# get permittivity with structures on top of an infinite version of this structure
1050-
structs_inf_struct = list(sim_orig.structures)[structure_index + 1 :]
1051-
sim_inf_structure = sim_orig.updated_copy(
1052-
structures=structs_inf_struct,
1053-
medium=structure.medium,
1054-
monitors=[],
1055-
)
1056-
eps_inf_structure = sim_inf_structure.epsilon(
1057-
box=plane_eps, coord_key="centers", freq=freq_adj
1058-
)
1028+
# get permittivity with structures on top of an infinite version of this structure
1029+
structs_inf_struct = list(sim_orig.structures)[structure_index + 1 :]
1030+
sim_inf_structure = sim_orig.updated_copy(
1031+
structures=structs_inf_struct,
1032+
medium=structure.medium,
1033+
monitors=[],
1034+
)
1035+
eps_inf_structure = sim_inf_structure.epsilon(
1036+
box=plane_eps, coord_key="centers", freq=freq_adj
1037+
)
10591038

1060-
# get minimum intersection of bounds with structure and sim
1061-
struct_bounds = rmin_struct, rmax_struct = structure.geometry.bounds
1062-
rmin_sim, rmax_sim = sim_data_orig.simulation.bounds
1063-
rmin_intersect = tuple([max(a, b) for a, b in zip(rmin_sim, rmin_struct)])
1064-
rmax_intersect = tuple([min(a, b) for a, b in zip(rmax_sim, rmax_struct)])
1065-
bounds_intersect = (rmin_intersect, rmax_intersect)
1066-
1067-
derivative_info = DerivativeInfo(
1068-
paths=structure_paths,
1069-
E_der_map=E_der_map.field_components,
1070-
D_der_map=D_der_map.field_components,
1071-
E_fwd=E_fwd.field_components,
1072-
E_adj=E_adj.field_components,
1073-
D_fwd=D_fwd.field_components,
1074-
D_adj=D_adj.field_components,
1075-
eps_data=eps_fwd.field_components,
1076-
eps_in=eps_in,
1077-
eps_out=eps_out,
1078-
eps_background=eps_background,
1079-
frequency=freq_adj,
1080-
eps_no_structure=eps_no_structure,
1081-
eps_inf_structure=eps_inf_structure,
1082-
bounds=struct_bounds,
1083-
bounds_intersect=bounds_intersect,
1084-
)
1039+
# get minimum intersection of bounds with structure and sim
1040+
struct_bounds = rmin_struct, rmax_struct = structure.geometry.bounds
1041+
rmin_sim, rmax_sim = sim_data_orig.simulation.bounds
1042+
rmin_intersect = tuple([max(a, b) for a, b in zip(rmin_sim, rmin_struct)])
1043+
rmax_intersect = tuple([min(a, b) for a, b in zip(rmax_sim, rmax_struct)])
1044+
bounds_intersect = (rmin_intersect, rmax_intersect)
1045+
1046+
derivative_info = DerivativeInfo(
1047+
paths=structure_paths,
1048+
E_der_map=E_der_map.field_components,
1049+
D_der_map=D_der_map.field_components,
1050+
E_fwd=E_fwd.field_components,
1051+
E_adj=E_adj.field_components,
1052+
D_fwd=D_fwd.field_components,
1053+
D_adj=D_adj.field_components,
1054+
eps_data=eps_fwd.field_components,
1055+
eps_in=eps_in,
1056+
eps_out=eps_out,
1057+
eps_background=eps_background,
1058+
frequency=freq_adj,
1059+
eps_no_structure=eps_no_structure,
1060+
eps_inf_structure=eps_inf_structure,
1061+
bounds=struct_bounds,
1062+
bounds_intersect=bounds_intersect,
1063+
)
10851064

1086-
vjp_value_map = structure.compute_derivatives(derivative_info)
1065+
vjp_value_map = structure.compute_derivatives(derivative_info)
10871066

10881067
# extract VJPs and put back into sim_fields_vjp AutogradFieldMap
10891068
for structure_path, vjp_value in vjp_value_map.items():

0 commit comments

Comments
 (0)