feat(tidy3d): FXC-4607-autograd-for-clip-operation#3104
feat(tidy3d): FXC-4607-autograd-for-clip-operation#3104marcorudolphflex wants to merge 1 commit intodevelopfrom
Conversation
|
@greptile |
tests/test_components/autograd/numerical/test_autograd_clip_operation_numerical.py
Outdated
Show resolved
Hide resolved
tests/test_components/autograd/numerical/test_autograd_clip_operation_numerical.py
Outdated
Show resolved
Hide resolved
tests/test_components/autograd/numerical/test_autograd_clip_operation_numerical.py
Outdated
Show resolved
Hide resolved
tests/test_components/autograd/numerical/test_autograd_clip_operation_numerical.py
Outdated
Show resolved
Hide resolved
91de1d8 to
d6cccf8
Compare
tests/test_components/autograd/numerical/test_autograd_clip_operation_numerical.py
Outdated
Show resolved
Hide resolved
d6cccf8 to
ec881d3
Compare
tests/test_components/autograd/numerical/test_autograd_clip_operation_numerical.py
Outdated
Show resolved
Hide resolved
ec881d3 to
764c0f0
Compare
Diff CoverageDiff: origin/develop...HEAD, staged and unstaged changes
Summary
tidy3d/components/geometry/base.pyLines 1494-1502 1494 derivative_info: DerivativeInfo,
1495 clip_operation: Optional[ClipOperationContext] = None,
1496 ) -> AutogradFieldMap:
1497 """Compute the adjoint derivatives for this object."""
! 1498 raise NotImplementedError(
1499 f"Can't compute derivative for clipped 'Geometry': '{type(self)}'."
1500 )
1501
1502 def _as_union(self) -> list[Geometry]:Lines 2654-2662 2654 finally:
2655 derivative_info.paths = original_paths
2656 gradient_key = ("mesh_dataset", "surface_mesh")
2657 if gradient_key not in mesh_vjps:
! 2658 return {}
2659
2660 triangle_grads = mesh_vjps[gradient_key]
2661 vertex_grads = np.asarray(self._accumulate_vertex_gradients(triangle_grads), dtype=float)
2662 vjps_faces = np.zeros((2, 3), dtype=vertex_grads.dtype)Lines 3300-3308 3300 if which == "geometry_a":
3301 return self.geometry_a
3302 if which == "geometry_b":
3303 return self.geometry_b
! 3304 raise ValueError(f"Unsupported geometry key '{which}'.")
3305
3306 def _other_geometry(self, which: ClipGeometryKey) -> Geometry:
3307 """Return the opposing geometry for ``which``."""
3308 return self.geometry_b if which == "geometry_a" else self.geometry_aLines 3312-3320 3312 """Convert sample points to a 2D array and track whether input was a single point."""
3313 arr = np.asarray(points, dtype=float)
3314 single_point = arr.ndim == 1
3315 if arr.size == 0:
! 3316 arr = arr.reshape((0, 3))
3317 else:
3318 arr = arr.reshape((-1, 3))
3319 return arr, single_pointLines 3332-3340 3332 if operation == "difference":
3333 return (~mask) if which == "geometry_a" else mask.copy()
3334 if operation == "symmetric_difference":
3335 return np.ones_like(mask, dtype=bool)
! 3336 raise ValueError(f"Unsupported clip operation '{operation}'.")
3337
3338 @staticmethod
3339 def _clip_flip_mask(
3340 operation: ClipOperationType, which: ClipGeometryKey, inside_mask: np.ndarrayLines 3365-3374 3365 """Return (use_mask, flip_mask, single_point) for sample points."""
3366
3367 points_arr, single_point = self._points_to_array(points)
3368 if points_arr.size == 0:
! 3369 empty = np.zeros(0, dtype=bool)
! 3370 return empty, empty, single_point
3371
3372 other = self._other_geometry(which)
3373 inside_other = np.asarray(
3374 other.inside(points_arr[:, 0], points_arr[:, 1], points_arr[:, 2]), dtype=boolLines 3380-3389 3380 def sample_points_should_use(
3381 self, which: ClipGeometryKey, points: ArrayLike
3382 ) -> Union[bool, NDArray[np.bool_]]:
3383 """Return a mask indicating which samples contribute to the gradient."""
! 3384 use_mask, _, single_point = self._clip_masks_for_points(which, points)
! 3385 return bool(use_mask[0]) if single_point else use_mask
3386
3387 def sample_normals_should_flip(
3388 self, which: ClipGeometryKey, points: ArrayLike
3389 ) -> Union[bool, NDArray[np.bool_]]:Lines 3387-3396 3387 def sample_normals_should_flip(
3388 self, which: ClipGeometryKey, points: ArrayLike
3389 ) -> Union[bool, NDArray[np.bool_]]:
3390 """Return a mask indicating which sample normals require flipping."""
! 3391 _, flip_mask, single_point = self._clip_masks_for_points(which, points)
! 3392 return bool(flip_mask[0]) if single_point else flip_mask
3393
3394 def intersections_tilted_plane(
3395 self,
3396 normal: Coordinate,Lines 3603-3618 3603 interpolators = derivative_info.interpolators or derivative_info.create_interpolators()
3604
3605 for field_path in derivative_info.paths:
3606 if not field_path:
! 3607 continue
3608 which, *geo_path = field_path
3609 if which not in ("geometry_a", "geometry_b"):
! 3610 raise ValueError(
3611 "ClipOperation derivatives are only defined for 'geometry_a' or 'geometry_b'."
3612 )
3613 if not geo_path:
! 3614 raise ValueError("ClipOperation derivative path must specify a geometry field.")
3615 geometry = self._geometry_from_key(which)
3616 geo_info = derivative_info.updated_copy(
3617 paths=[tuple(geo_path)],
3618 bounds=geometry.bounds,Lines 3628-3636 3628 else:
3629 clip_context = (context, clip_operation)
3630 vjps_geo = geometry._compute_derivatives_via_mesh(geo_info, clip_operation=clip_context)
3631 if len(vjps_geo) != 1:
! 3632 raise AssertionError("Expected a single gradient value for each geometry field.")
3633 grad_vjps[field_path] = vjps_geo.popitem()[1]
3634
3635 return grad_vjpsLines 3863-3871 3863 self,
3864 derivative_info: DerivativeInfo,
3865 clip_operation: Optional[ClipOperationContext] = None,
3866 ) -> AutogradFieldMap:
! 3867 return self._compute_derivatives(derivative_info, clip_operation=clip_operation)
3868
3869 def _compute_derivatives(
3870 self,
3871 derivative_info: DerivativeInfo,Lines 3891-3899 3891 deep=False,
3892 interpolators=interpolators,
3893 )
3894 if clip_operation is not None:
! 3895 vjp_dict_geo = geo._compute_derivatives_via_mesh(
3896 geo_info, clip_operation=clip_operation
3897 )
3898 else:
3899 vjp_dict_geo = geo._compute_derivatives(geo_info)tidy3d/components/geometry/mesh.pyLines 875-883 875 and isinstance(entry[0], base.ClipOperation)
876 ):
877 contexts.append(entry)
878 if not contexts:
! 879 raise ValueError("Invalid ClipOperation context provided.")
880 return contexts
881
882 def _apply_clip_filters_single(
883 self,Lines 896-904 896 points = np.asarray(samples["points"], dtype=config.adjoint.gradient_dtype_float)
897 normals = np.asarray(samples["normals"], dtype=config.adjoint.gradient_dtype_float)
898 total_points = points.shape[0]
899 if total_points == 0:
! 900 return samples
901
902 shift = max(float(config.adjoint.edge_clip_tolerance), 1e-9)
903 probe_points = points - normals * shift
904 inside_mask = np.asarray(Lines 911-919 911 ).reshape(-1)
912
913 use_mask, flip_mask = clip_obj._clip_masks_from_inside(which, inside_mask)
914 if use_mask.size != total_points:
! 915 raise ValueError("ClipOperation sample mask has incorrect shape.")
916 if not np.any(use_mask):
917 return {key: np.asarray(value[:0]).copy() for key, value in samples.items()}
918
919 filtered = {key: np.asarray(value[use_mask]).copy() for key, value in samples.items()}Lines 918-926 918
919 filtered = {key: np.asarray(value[use_mask]).copy() for key, value in samples.items()}
920
921 if flip_mask.size != total_points:
! 922 raise ValueError("ClipOperation normal flip mask has incorrect shape.")
923 flip_mask = flip_mask[use_mask]
924 if np.any(flip_mask):
925 flip_signs = np.where(flip_mask[:, None], -1.0, 1.0)
926 filtered["normals"] = filtered["normals"] * flip_signsLines 937-945 937 clip_obj, which = clip_operation
938 points = samples["points"]
939 total_points = points.shape[0]
940 if total_points == 0:
! 941 return {key: np.asarray(value[:0]).copy() for key, value in samples.items()}
942
943 use_mask, flip_mask, _ = clip_obj._clip_masks_for_points(which, points)
944 if not np.any(use_mask):
945 return {key: np.asarray(value[:0]).copy() for key, value in samples.items()}Lines 957-973 957 def _prepare_clip_geometry(other: base.Geometry) -> base.Geometry:
958 """Return a TriangleMesh suitable for geometric clipping operations."""
959
960 if not isinstance(other, TriangleMesh):
! 961 return other
962
963 try:
964 tri_mesh = other.trimesh
! 965 except Exception:
! 966 return other
967
968 if not tri_mesh.is_volume:
! 969 raise ValueError(
970 "ClipOperation requires volume TriangleMesh geometry for clip filtering."
971 )
972
973 return otherLines 991-999 991
992 spacing = max(float(spacing), np.finfo(float).eps)
993 triangles_arr = np.asarray(triangles, dtype=dtype)
994 if triangles_arr.size == 0:
! 995 return self._empty_sample_result(dtype)
996
997 edges01 = triangles_arr[:, 1, :] - triangles_arr[:, 0, :]
998 edges02 = triangles_arr[:, 2, :] - triangles_arr[:, 0, :]
999 edges12 = triangles_arr[:, 2, :] - triangles_arr[:, 1, :]Lines 1088-1096 1088
1089 for group_idx, group_subdiv in enumerate(unique_subdiv):
1090 group_faces = np.flatnonzero(inverse == group_idx)
1091 if group_faces.size == 0:
! 1092 continue
1093 barycentric = self._get_barycentric_samples(int(group_subdiv), dtype)
1094 self._append_barycentric_group(
1095 samples=samples,
1096 barycentric=barycentric,Lines 1115-1128 1115 "Some triangles from the mesh lie outside the simulation bounds - this may lead to inaccurate gradients."
1116 )
1117 warned = True
1118 if not clipped:
! 1119 continue
1120
1121 for tri_clip in clipped:
1122 area_clip, _ = self._triangle_area_and_normal(tri_clip)
1123 if area_clip <= AREA_SIZE_THRESHOLD:
! 1124 continue
1125
1126 edge_lengths = (
1127 np.linalg.norm(tri_clip[1] - tri_clip[0]),
1128 np.linalg.norm(tri_clip[2] - tri_clip[1]),Lines 1193-1201 1193 for start, end in segments:
1194 vec = end - start
1195 length = float(np.linalg.norm(vec))
1196 if length <= tol:
! 1197 continue
1198
1199 subdivisions = max(1, int(np.ceil(length / spacing)))
1200 t_vals = (np.arange(subdivisions, dtype=dtype) + 0.5) / subdivisions
1201 sample_points = start[None, :] + t_vals[:, None] * vec[None, :]Lines 1210-1224 1210 coords <= max_bound, axis=1
1211 )
1212
1213 if not np.all(inside_mask) and not warned:
! 1214 log.warning(
1215 "Some triangles from the mesh lie outside the simulation bounds - this may lead to inaccurate gradients."
1216 )
! 1217 warned = True
1218
1219 if not np.any(inside_mask):
! 1220 continue
1221
1222 sample_points = sample_points[inside_mask]
1223 bary_inside = barycentric[inside_mask]
1224 n_inside = sample_points.shape[0]Lines 1259-1270 1259 @staticmethod
1260 def _empty_sample_result(dtype: np.dtype) -> dict[str, np.ndarray]:
1261 """Return the default empty sampling dictionary."""
1262
! 1263 zeros_vec = np.zeros((0, 3), dtype=dtype)
! 1264 zeros_scalar = np.zeros((0,), dtype=dtype)
! 1265 zeros_faces = np.zeros((0,), dtype=int)
! 1266 return {
1267 "points": zeros_vec,
1268 "normals": zeros_vec.copy(),
1269 "perps1": zeros_vec.copy(),
1270 "perps2": zeros_vec.copy(),Lines 1325-1333 1325 ) -> dict[str, np.ndarray]:
1326 """Concatenate accumulated sampling data or return an empty structure."""
1327
1328 if not samples.points:
! 1329 return TriangleMesh._empty_sample_result(dtype)
1330
1331 return {
1332 "points": np.concatenate(samples.points, axis=0),
1333 "normals": np.concatenate(samples.normals, axis=0),Lines 1450-1458 1450 ) -> list[np.ndarray]:
1451 """Clip a polygon with an axis-aligned plane."""
1452
1453 if not polygon:
! 1454 return []
1455
1456 result: list[np.ndarray] = []
1457 prev = polygon[-1]
1458 prev_val = prev[axis]Lines 1487-1495 1487 v0 = float(p0[axis]) - bound
1488 v1 = float(p1[axis]) - bound
1489 denom = v1 - v0
1490 if abs(denom) <= tol:
! 1491 return p0.copy()
1492 t = -v0 / denom
1493 t = float(np.clip(t, 0.0, 1.0))
1494 return p0 + t * (p1 - p0)Lines 1503-1511 1503 inside = np.all(vertices >= (sim_min - tol), axis=1) & np.all(
1504 vertices <= (sim_max + tol), axis=1
1505 )
1506 if np.all(inside):
! 1507 return [triangle], False
1508
1509 polygon = [triangle[0].copy(), triangle[1].copy(), triangle[2].copy()]
1510 clipped_flag = True
1511 for axis in range(3):Lines 1510-1524 1510 clipped_flag = True
1511 for axis in range(3):
1512 polygon = cls._clip_polygon_with_plane(polygon, axis, sim_min[axis], False, tol)
1513 if not polygon:
! 1514 return [], True
1515 polygon = cls._clip_polygon_with_plane(polygon, axis, sim_max[axis], True, tol)
1516 if not polygon:
! 1517 return [], True
1518
1519 if len(polygon) < 3:
! 1520 return [], True
1521
1522 triangles: list[NDArray] = []
1523 anchor = polygon[0]
1524 for idx in range(1, len(polygon) - 1):tidy3d/components/geometry/polyslab.pyLines 1475-1483 1475 ) -> AutogradFieldMap:
1476 """Compute adjoint derivatives via mesh-based sampling."""
1477
1478 if not self._mesh_derivatives_supported():
! 1479 return self._zero_derivative_map(derivative_info)
1480
1481 dtype = config.adjoint.gradient_dtype_float
1482 vertices_arr = np.asarray(self.vertices, dtype=dtype)
1483 slab_bounds_arr = np.asarray(self.slab_bounds, dtype=dtype)Lines 1483-1491 1483 slab_bounds_arr = np.asarray(self.slab_bounds, dtype=dtype)
1484 sidewall_angle_val = np.array(self.sidewall_angle, dtype=dtype)
1485
1486 if vertices_arr.shape[0] < 3:
! 1487 return self._zero_derivative_map(derivative_info)
1488
1489 mesh_vertices, base_polygon, top_polygon = self._mesh_vertex_positions(
1490 vertices=vertices_arr,
1491 slab_bounds=slab_bounds_arr,Lines 1494-1502 1494
1495 faces, partitions = self._ensure_mesh_faces(base_polygon, top_polygon)
1496
1497 if mesh_vertices.size == 0 or faces.size == 0:
! 1498 return self._zero_derivative_map(derivative_info)
1499
1500 from .mesh import TriangleMesh
1501
1502 mesh = TriangleMesh.from_vertices_faces(mesh_vertices, faces)Lines 1508-1516 1508 finally:
1509 derivative_info.paths = original_paths
1510 gradient_key = ("mesh_dataset", "surface_mesh")
1511 if gradient_key not in mesh_vjps:
! 1512 return self._zero_derivative_map(derivative_info)
1513
1514 triangle_grads = mesh_vjps[gradient_key]
1515 num_vertices = mesh_vertices.shape[0]
1516 base_slice = partitions["base"]Lines 1622-1641 1622
1623 dtype = config.adjoint.gradient_dtype_float
1624
1625 def empty_result() -> tuple[NDArray, NDArray, NDArray]:
! 1626 verts3d = np.zeros((0, 3), dtype=dtype)
! 1627 polys = np.zeros((0, 2), dtype=dtype)
! 1628 return verts3d, polys, polys
1629
1630 reference_polygon = PolySlab._proper_vertices(vertices)
1631 if reference_polygon.shape[0] < 3:
! 1632 return empty_result()
1633
1634 bounds_vals = np.array([getval(slab_bounds[0]), getval(slab_bounds[1])], dtype=float)
1635 length_val = bounds_vals[1] - bounds_vals[0]
1636 if length_val <= fp_eps:
! 1637 return empty_result()
1638
1639 zmin = np.maximum(slab_bounds[0], -LARGE_NUMBER)
1640 zmax = np.minimum(slab_bounds[1], LARGE_NUMBER)
1641 finite_length = zmax - zminLines 1644-1664 1644 tan_val = np.tan(sidewall_angle)
1645 offset = np.where(np.isclose(tan_val, 0.0), 0.0, -half_length * tan_val)
1646
1647 if self.reference_plane == "bottom":
! 1648 middle_polygon = PolySlab._shift_vertices(reference_polygon, offset)[0]
1649 elif self.reference_plane == "top":
! 1650 middle_polygon = PolySlab._shift_vertices(reference_polygon, -offset)[0]
1651 else:
1652 middle_polygon = reference_polygon
1653
1654 if self.reference_plane == "bottom":
! 1655 base_polygon = reference_polygon
1656 else:
1657 base_polygon = PolySlab._shift_vertices(middle_polygon, -offset)[0]
1658
1659 if self.reference_plane == "top":
! 1660 top_polygon = reference_polygon
1661 else:
1662 top_polygon = PolySlab._shift_vertices(middle_polygon, offset)[0]
1663
1664 planar = np.vstack((base_polygon, top_polygon))Lines 1683-1703 1683 ) -> tuple[NDArray[np.int_], dict[str, slice]]:
1684 """Construct (and cache) the triangle indices for the PolySlab mesh."""
1685
1686 if self._mesh_faces is not None:
! 1687 return self._mesh_faces
1688
1689 def empty_faces() -> tuple[NDArray[np.int_], dict[str, slice]]:
! 1690 faces = np.zeros((0, 3), dtype=int)
! 1691 empty = slice(0, 0)
! 1692 partitions = {"base": empty, "top": empty, "side": empty}
! 1693 self._mesh_faces = (faces, partitions)
! 1694 return self._mesh_faces
1695
1696 n_base = int(base_polygon.shape[0])
1697 n_top = int(top_polygon.shape[0])
1698 if n_base < 3 or n_top < 3 or n_base != n_top:
! 1699 return empty_faces()
1700
1701 try:
1702 base_triangles = triangulation.triangulate(base_polygon)
1703 if math.isclose(self.sidewall_angle, 0):Lines 1702-1713 1702 base_triangles = triangulation.triangulate(base_polygon)
1703 if math.isclose(self.sidewall_angle, 0):
1704 top_triangles = base_triangles
1705 else:
! 1706 top_triangles = triangulation.triangulate(top_polygon)
! 1707 except Exception as exc:
! 1708 log.debug("Failed to triangulate 'PolySlab' mesh faces: %s", exc)
! 1709 return empty_faces()
1710
1711 base_faces = [[a, b, c] for c, b, a in base_triangles]
1712 top_shift = n_base
1713 top_faces = [[top_shift + a, top_shift + b, top_shift + c] for a, b, c in top_triangles]Lines 1731-1743 1731 ) -> NDArray:
1732 """Aggregate per-triangle gradients into per-vertex values."""
1733
1734 if triangle_grads.size == 0 or faces.size == 0:
! 1735 length = int(num_vertices or 0)
! 1736 return np.zeros((length, 3), dtype=triangle_grads.dtype)
1737
1738 if num_vertices is None:
! 1739 num_vertices = int(faces.max() + 1)
1740
1741 vertex_grads = np.zeros((num_vertices, 3), dtype=triangle_grads.dtype)
1742 for face_index, face in enumerate(faces):
1743 for local_idx, vertex_idx in enumerate(face):Lines 1764-1777 1764 grad_vertices = grad_vertices_side + grad_vertices_caps
1765 grad_vertices *= self._planar_orientation_sign()
1766 grad_bounds *= self._planar_orientation_sign()
1767 if self._is_2d_slice(derivative_info):
! 1768 slab_thickness = float(getval(slab_bounds[1]) - getval(slab_bounds[0]))
! 1769 if not np.isfinite(slab_thickness) or slab_thickness <= fp_eps:
! 1770 thickness = 1.0
1771 else:
! 1772 thickness = slab_thickness
! 1773 grad_vertices /= thickness
1774
1775 sim_min, sim_max = map(np.asarray, derivative_info.simulation_bounds)
1776 intersect_min, intersect_max = map(np.asarray, derivative_info.bounds_intersect)
1777 is_2d = np.isclose(intersect_max[self.axis] - intersect_min[self.axis], 0.0)Lines 1775-1783 1775 sim_min, sim_max = map(np.asarray, derivative_info.simulation_bounds)
1776 intersect_min, intersect_max = map(np.asarray, derivative_info.bounds_intersect)
1777 is_2d = np.isclose(intersect_max[self.axis] - intersect_min[self.axis], 0.0)
1778 if is_2d:
! 1779 grad_bounds = np.zeros_like(grad_bounds)
1780 interpolators = derivative_info.interpolators or derivative_info.create_interpolators(
1781 dtype=config.adjoint.gradient_dtype_float
1782 )
1783 grad_angle_exact = self._compute_derivative_sidewall_angle(Lines 1792-1805 1792 for path in derivative_info.paths:
1793 if path == ("vertices",):
1794 results[path] = grad_vertices
1795 elif path == ("sidewall_angle",):
! 1796 results[path] = float(grad_angle_exact)
1797 elif path[0] == "slab_bounds":
1798 idx = int(path[1])
1799 results[path] = float(grad_bounds[idx])
1800 else:
! 1801 raise ValueError(f"No derivative defined w.r.t. 'PolySlab' field '{path}'.")
1802
1803 return results
1804
1805 def _planar_orientation_sign(self) -> float:Lines 1819-1837 1819
1820 def _zero_derivative_map(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
1821 """Return a zero-valued derivative map for requested fields."""
1822
! 1823 result: AutogradFieldMap = {}
! 1824 for path in derivative_info.paths:
! 1825 if path == ("vertices",):
! 1826 result[path] = np.zeros_like(self.vertices)
! 1827 elif path == ("sidewall_angle",):
! 1828 result[path] = 0.0
! 1829 elif path[0] == "slab_bounds":
! 1830 result[path] = 0.0
1831 else:
! 1832 raise ValueError(f"No derivative defined w.r.t. 'PolySlab' field '{path}'.")
! 1833 return result
1834
1835 def _mesh_parameter_gradients(
1836 self,
1837 vertex_grads: NDArray,tidy3d/components/geometry/primitives.pyLines 363-379 363
364 self._validate_derivative_paths(derivative_info)
365
366 if not derivative_info.paths:
! 367 return {}
368
369 radius = float(get_static(self.radius))
370 if radius == 0.0:
! 371 log.warning(
372 "Sphere gradients cannot be computed for zero radius; gradients are zero.",
373 log_once=True,
374 )
! 375 return self._zero_derivative_map(derivative_info)
376
377 grid_cfg = config.adjoint
378 wvl_mat = discretization_wavelength(derivative_info, "sphere")
379 target_edge = max(wvl_mat / grid_cfg.points_per_wavelength, np.finfo(float).eps)Lines 379-387 379 target_edge = max(wvl_mat / grid_cfg.points_per_wavelength, np.finfo(float).eps)
380 triangles, _ = self._triangulated_surface(max_edge_length=target_edge)
381 triangles = np.asarray(triangles, dtype=grid_cfg.gradient_dtype_float)
382 if triangles.size == 0:
! 383 return self._zero_derivative_map(derivative_info)
384
385 mesh = TriangleMesh.from_triangles(triangles)
386 original_paths = derivative_info.paths
387 derivative_info.paths = [("mesh_dataset", "surface_mesh")]Lines 391-403 391 derivative_info.paths = original_paths
392
393 gradient_key = ("mesh_dataset", "surface_mesh")
394 if gradient_key not in mesh_vjps:
! 395 return self._zero_derivative_map(derivative_info)
396
397 triangle_grads = np.asarray(mesh_vjps[gradient_key], dtype=float)
398 if triangle_grads.size == 0:
! 399 return self._zero_derivative_map(derivative_info)
400
401 center = np.asarray(self.center, dtype=float)
402 relative = triangles - center
403 norms = np.linalg.norm(relative, axis=2, keepdims=True)Lines 410-422 410 result: AutogradFieldMap = {}
411 for path in derivative_info.paths:
412 if path == ("radius",):
413 result[path] = grad_radius
! 414 elif path[0] == "center":
! 415 idx = int(path[1])
! 416 result[path] = float(grad_center[idx])
417 else:
! 418 raise ValueError(f"No derivative defined w.r.t. 'Sphere' field '{path}'.")
419
420 return result
421
422 def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:Lines 449-457 449 sim_extents = sim_max - sim_min
450 collapsed_indices = np.flatnonzero(np.isclose(sim_extents, 0.0, atol=tol))
451 if collapsed_indices.size:
452 if collapsed_indices.size > 1:
! 453 return self._zero_derivative_map(derivative_info)
454 axis_idx = int(collapsed_indices[0])
455 plane_value = float(sim_min[axis_idx])
456 return self._compute_derivatives_collapsed_axis(
457 derivative_info=derivative_info,Lines 467-475 467 norms = np.where(norms == 0, 1, norms)
468 normals = verts_centered / norms
469
470 if vertices.size == 0:
! 471 return self._zero_derivative_map(derivative_info)
472
473 # get vertex weights
474 faces = np.asarray(trimesh_obj.faces, dtype=int)
475 face_areas = np.asarray(trimesh_obj.area_faces, dtype=grid_cfg.gradient_dtype_float)Lines 485-493 485 vertices[:, valid_axes] >= (sim_min - tol)[valid_axes], axis=1
486 ) & np.all(vertices[:, valid_axes] <= (sim_max + tol)[valid_axes], axis=1)
487
488 if not np.any(inside_mask):
! 489 return self._zero_derivative_map(derivative_info)
490
491 points = vertices[inside_mask]
492 normals_sel = normals[inside_mask]
493 perp1_sel = perp1[inside_mask]Lines 885-893 885 update_kwargs["interpolators"] = derivative_info.interpolators
886
887 derivative_info_polyslab = derivative_info.updated_copy(**update_kwargs)
888 if clip_operation is not None:
! 889 vjps_polyslab = polyslab._compute_derivatives_via_mesh(
890 derivative_info_polyslab, clip_operation=clip_operation
891 )
892 else:
893 vjps_polyslab = polyslab._compute_derivatives(derivative_info_polyslab) |
764c0f0 to
2d71ce9
Compare
2d71ce9 to
05bf86f
Compare
2e16f93 to
2a43f2b
Compare
2a43f2b to
5d11c0b
Compare
5d11c0b to
12bfb6b
Compare
12bfb6b to
e64e1f5
Compare
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable autofix in the Cursor dashboard.
| def _compute_derivatives_via_mesh( | ||
| self, | ||
| derivative_info: DerivativeInfo, | ||
| clip_operation: base.Optional[ClipOperationContext] = None, |
There was a problem hiding this comment.
Typo: base.Optional instead of Optional in type hint
Low Severity
Cylinder._compute_derivatives_via_mesh uses base.Optional[ClipOperationContext] instead of Optional[ClipOperationContext]. While this works at runtime due to from __future__ import annotations making annotations lazy strings, it's inconsistent with the adjacent method on line 845 which correctly uses Optional[base.ClipOperationContext]. This accesses Optional through the base geometry module rather than using the locally imported Optional from typing.
|
migrated |


Greptile Summary
ClipOperationgeometries like unions, intersections, and differences to enable gradient-based optimization workflowsBox,Cylinder,PolySlab,TriangleMesh) for handling complex surface interactions during clipped operationsImportant Files Changed
Confidence score: 3/5
Context used (5)
dashboard- Remove commented-out or obsolete code; rely on version control for history. (source)dashboard- Remove temporary debugging code (print() calls), commented-out code, and other workarounds before fi... (source)dashboard- Update the CHANGELOG.md file when making changes that affect user-facing functionality or fix bugs. (source)dashboard- Use changelog categories correctly: "Fixed" for bug fixes, "Changed" for modifications to existing f... (source)dashboard- Assert the direct outcome of an operation rather than a side effect (like a log message) when possib... (source)Note
High Risk
Touches core geometry/autograd derivative paths and refactors
TriangleMeshsurface sampling/clipping logic, so gradient correctness and numerical stability could regress across multiple geometry types and boolean operations.Overview
Adds autograd support for
ClipOperation(union/intersection/difference/xor) by introducing a mesh-based derivative pathway (_compute_derivatives_via_mesh) and routing clipped gradients through operand-specific sampling masks and normal flipping.Implements mesh-based derivative backends for key geometries (
Box,Sphere,PolySlab,Cylinder,TriangleMesh) and updatesGeometryGroupto propagate clip context;TriangleMeshderivative sampling is heavily refactored (vectorized sampling, triangle clipping to sim bounds, and nested clip-context filtering).Expands test coverage with new unit + numerical finite-difference validation for
ClipOperation, updates an existing test to assert gradients are produced (not an error), and hardens numerical artifact directory naming to avoid filesystem path-length issues.Written by Cursor Bugbot for commit e64e1f5. This will update automatically on new commits. Configure here.