Skip to content

Commit 76ab553

Browse files
Allow transformed groups to be flattened (#2050)
Related to: - flexcompute/tidy3d-core#751 - flexcompute/tidy3d-core#750 Signed-off-by: Lucas Heitzmann Gabrielli <[email protected]>
1 parent 164b4b8 commit 76ab553

File tree

4 files changed

+58
-5
lines changed

4 files changed

+58
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111
- Autograd support for local field projections using `FieldProjectionKSpaceMonitor`.
12+
- Function `components.geometry.utils.flatten_groups` now also flattens transformed groups when requested.
1213

1314
### Fixed
1415
- Regression in local field projection leading to incorrect results for `far_field_approx=True`.

tests/test_components/test_geometry.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,33 @@ def test_flattening():
524524
for g in flat
525525
)
526526

527+
t0 = np.array([[2, 0, 0, 0], [3, 2, 0, 0], [1, 0, 2, 0], [0, 0, 0, 1.0]])
528+
g0 = td.Sphere(radius=1)
529+
t1 = np.array([[2, 0, 5, 0], [0, 1, 0, 0], [-1, 0, 1, 0], [0, 0, 0, 1.0]])
530+
g1 = td.Box(size=(1, 2, 3))
531+
flat = list(
532+
flatten_groups(
533+
td.Transformed(
534+
transform=t0,
535+
geometry=td.ClipOperation(
536+
operation="union",
537+
geometry_a=g0,
538+
geometry_b=td.Transformed(transform=t1, geometry=g1),
539+
),
540+
),
541+
flatten_transformed=True,
542+
)
543+
)
544+
assert len(flat) == 2
545+
546+
assert isinstance(flat[0], td.Transformed)
547+
assert flat[0].geometry == g0
548+
assert np.allclose(flat[0].transform, t0)
549+
550+
assert isinstance(flat[1], td.Transformed)
551+
assert flat[1].geometry == g1
552+
assert np.allclose(flat[1].transform, t0 @ t1)
553+
527554

528555
def test_geometry_traversal():
529556
geometries = list(traverse_geometries(td.Box(size=(1, 1, 1))))

tidy3d/components/geometry/utils.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from math import isclose
6-
from typing import Tuple, Union
6+
from typing import Optional, Tuple, Union
77

88
import numpy as np
99

@@ -24,17 +24,25 @@
2424
]
2525

2626

27-
def flatten_groups(*geometries: GeometryType, flatten_nonunion_type: bool = False) -> GeometryType:
27+
def flatten_groups(
28+
*geometries: GeometryType,
29+
flatten_nonunion_type: bool = False,
30+
flatten_transformed: bool = False,
31+
transform: Optional[MatrixReal4x4] = None,
32+
) -> GeometryType:
2833
"""Iterates over all geometries, flattening groups and unions.
2934
3035
Parameters
3136
----------
3237
*geometries : GeometryType
3338
Geometries to flatten.
34-
3539
flatten_nonunion_type : bool = False
3640
If ``False``, only flatten geometry unions (and ``GeometryGroup``). If ``True``, flatten
3741
all clip operations.
42+
flatten_transformed : bool = False
43+
If ``True``, ``Transformed`` groups are flattened into individual transformed geometries.
44+
transform : Optional[MatrixReal4x4]
45+
Accumulated transform from parents. Only used when ``flatten_transformed`` is ``True``.
3846
3947
Yields
4048
------
@@ -44,7 +52,10 @@ def flatten_groups(*geometries: GeometryType, flatten_nonunion_type: bool = Fals
4452
for geometry in geometries:
4553
if isinstance(geometry, base.GeometryGroup):
4654
yield from flatten_groups(
47-
*geometry.geometries, flatten_nonunion_type=flatten_nonunion_type
55+
*geometry.geometries,
56+
flatten_nonunion_type=flatten_nonunion_type,
57+
flatten_transformed=flatten_transformed,
58+
transform=transform,
4859
)
4960
elif isinstance(geometry, base.ClipOperation) and (
5061
flatten_nonunion_type or geometry.operation == "union"
@@ -53,7 +64,21 @@ def flatten_groups(*geometries: GeometryType, flatten_nonunion_type: bool = Fals
5364
geometry.geometry_a,
5465
geometry.geometry_b,
5566
flatten_nonunion_type=flatten_nonunion_type,
67+
flatten_transformed=flatten_transformed,
68+
transform=transform,
69+
)
70+
elif flatten_transformed and isinstance(geometry, base.Transformed):
71+
new_transform = geometry.transform
72+
if transform is not None:
73+
new_transform = np.matmul(transform, new_transform)
74+
yield from flatten_groups(
75+
geometry.geometry,
76+
flatten_nonunion_type=flatten_nonunion_type,
77+
flatten_transformed=flatten_transformed,
78+
transform=new_transform,
5679
)
80+
elif flatten_transformed and transform is not None:
81+
yield base.Transformed(geometry=geometry, transform=transform)
5782
else:
5883
yield geometry
5984

tidy3d/components/scene.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _validate_num_geometries(cls, val):
116116
return val
117117

118118
for i, structure in enumerate(val):
119-
for geometry in flatten_groups(structure.geometry):
119+
for geometry in flatten_groups(structure.geometry, flatten_transformed=True):
120120
count = sum(
121121
1
122122
for g in traverse_geometries(geometry)

0 commit comments

Comments
 (0)