Skip to content

Commit 68063e0

Browse files
committed
Geometry.plot() and Structure().plot() accept the "transpose" argument which swaps the horizontal and vertical axes"
1 parent 53bedec commit 68063e0

File tree

12 files changed

+595
-132
lines changed

12 files changed

+595
-132
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- `Geometry.plot()` and `Structure().plot()` now accept the `transpose=True` argument which swaps the horizontal and vertical axes of the plot.
12+
1013
### Changed
1114
- By default, batch downloads will skip files that already exist locally. To force re-downloading and replace existing files, pass the `replace_existing=True` argument to `Batch.load()`, `Batch.download()`, or `BatchData.load()`.
1215
- The `BatchData.load_sim_data()` function now overwrites any previously downloaded simulation files (instead of skipping them).

tests/test_components/test_geometry.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,15 @@
8787
_, AX = plt.subplots()
8888

8989

90-
@pytest.mark.parametrize("component", GEO_TYPES)
91-
def test_plot(component):
92-
_ = component.plot(z=0, ax=AX)
90+
@pytest.mark.parametrize("component, transpose", zip(GEO_TYPES, [True, False]))
91+
def test_plot(component, transpose):
92+
_ = component.plot(z=0, ax=AX, transpose=transpose)
9393
plt.close()
9494

9595

96-
def test_plot_with_units():
97-
_ = BOX.plot(z=0, ax=AX, plot_length_units="nm")
96+
@pytest.mark.parametrize("transpose", [True, False])
97+
def test_plot_with_units(transpose):
98+
_ = BOX.plot(z=0, ax=AX, plot_length_units="nm", transpose=transpose)
9899
plt.close()
99100

100101

@@ -206,11 +207,11 @@ def test_array_to_vertices():
206207
assert np.all(np.array(vertices) == np.array(vertices2))
207208

208209

209-
@pytest.mark.parametrize("component", GEO_TYPES)
210-
def test_intersections_plane(component):
211-
assert len(component.intersections_plane(z=0.2)) > 0
212-
assert len(component.intersections_plane(x=0.2)) > 0
213-
assert len(component.intersections_plane(x=10000)) == 0
210+
@pytest.mark.parametrize("component, transpose", zip(GEO_TYPES, [True, False]))
211+
def test_intersections_plane(component, transpose):
212+
assert len(component.intersections_plane(z=0.2, transpose=transpose)) > 0
213+
assert len(component.intersections_plane(x=0.2, transpose=transpose)) > 0
214+
assert len(component.intersections_plane(x=10000, transpose=transpose)) == 0
214215

215216

216217
def test_intersections_plane_inf():
@@ -768,36 +769,36 @@ def test_geometry_touching_intersections_plane(x0):
768769

769770

770771
def test_pop_axis():
771-
b = td.Box(size=(1, 1, 1))
772772
for axis in range(3):
773773
coords = (1, 2, 3)
774-
Lz, (Lx, Ly) = b.pop_axis(coords, axis=axis)
775-
_coords = b.unpop_axis(Lz, (Lx, Ly), axis=axis)
774+
Lz, (Lx, Ly) = td.Box.pop_axis(coords, axis=axis)
775+
_coords = td.Box.unpop_axis(Lz, (Lx, Ly), axis=axis)
776776
assert all(c == _c for (c, _c) in zip(coords, _coords))
777-
_Lz, (_Lx, _Ly) = b.pop_axis(_coords, axis=axis)
777+
_Lz, (_Lx, _Ly) = td.Box.pop_axis(_coords, axis=axis)
778778
assert Lz == _Lz
779779
assert Lx == _Lx
780780
assert Ly == _Ly
781781

782782

783-
def test_2b_box_intersections():
783+
@pytest.mark.parametrize("transpose", [True, False])
784+
def test_2b_box_intersections(transpose):
784785
plane = td.Box(size=(1, 4, 0))
785786
box1 = td.Box(size=(1, 1, 1))
786787
box2 = td.Box(size=(1, 1, 1), center=(3, 0, 0))
787788

788-
result = plane.intersections_with(box1)
789+
result = plane.intersections_with(box1, transpose=transpose)
789790
assert len(result) == 1
790791
assert result[0].geom_type == "Polygon"
791-
assert len(plane.intersections_with(box2)) == 0
792+
assert len(plane.intersections_with(box2, transpose=transpose)) == 0
792793

793794
with pytest.raises(ValidationError):
794-
_ = box1.intersections_with(box2)
795+
_ = box1.intersections_with(box2, transpose=transpose)
795796

796-
assert len(box1.intersections_2dbox(plane)) == 1
797-
assert len(box2.intersections_2dbox(plane)) == 0
797+
assert len(box1.intersections_2dbox(plane, transpose=transpose)) == 1
798+
assert len(box2.intersections_2dbox(plane, transpose=transpose)) == 0
798799

799800
with pytest.raises(ValidationError):
800-
_ = box2.intersections_2dbox(box1)
801+
_ = box2.intersections_2dbox(box1, transpose=transpose)
801802

802803

803804
def test_polyslab_merge():
@@ -939,7 +940,8 @@ def test_to_gds(geometry, tmp_path):
939940
assert len(cell.polygons) == 0
940941

941942

942-
def test_custom_surface_geometry(tmp_path):
943+
@pytest.mark.parametrize("transpose", [True, False])
944+
def test_custom_surface_geometry(transpose, tmp_path):
943945
# create tetrahedron STL
944946
vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
945947
faces = np.array([[1, 2, 3], [0, 3, 2], [0, 1, 3], [0, 2, 1]])
@@ -972,9 +974,13 @@ def test_custom_surface_geometry(tmp_path):
972974
assert np.isclose(geom.volume(), 1 / 6)
973975

974976
# test intersections
975-
assert shapely.equals(geom.intersections_plane(x=0), shapely.Polygon([[0, 0], [0, 1], [1, 0]]))
976977
assert shapely.equals(
977-
geom.intersections_plane(z=0.5), shapely.Polygon([[0, 0], [0, 0.5], [0.5, 0]])
978+
geom.intersections_plane(x=0, transpose=transpose),
979+
shapely.Polygon([[0, 0], [0, 1], [1, 0]]),
980+
)
981+
assert shapely.equals(
982+
geom.intersections_plane(z=0.5, transpose=transpose),
983+
shapely.Polygon([[0, 0], [0, 0.5], [0.5, 0]]),
978984
)
979985

980986
# test inside
@@ -983,7 +989,7 @@ def test_custom_surface_geometry(tmp_path):
983989

984990
# test plot
985991
_, ax = plt.subplots()
986-
_ = geom.plot(z=0.1, ax=ax)
992+
_ = geom.plot(z=0.1, ax=ax, transpose=transpose)
987993
plt.close()
988994

989995
# test inconsistent winding
@@ -1033,7 +1039,7 @@ def test_custom_surface_geometry(tmp_path):
10331039
boundary_spec=td.BoundarySpec.all_sides(td.PML()),
10341040
)
10351041
_, ax = plt.subplots()
1036-
_ = sim.plot(y=0, ax=ax)
1042+
_ = sim.plot(y=0, ax=ax, transpose=transpose)
10371043
plt.close()
10381044

10391045
# allow small triangles

tests/test_components/test_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Tests objects shared by multiple components."""
2+
3+
from __future__ import annotations
4+
5+
import random
6+
7+
import pytest
8+
from shapely.geometry import Point
9+
10+
from tidy3d.components.utils import pop_axis_and_swap, shape_swap_xy, unpop_axis_and_swap
11+
12+
13+
@pytest.mark.parametrize("transpose", [True, False])
14+
def test_pop_axis_and_swap(transpose):
15+
for axis in range(3):
16+
coords = (1, 2, 3)
17+
Lz, (Lx, Ly) = pop_axis_and_swap(coords, axis=axis, transpose=transpose)
18+
_coords = unpop_axis_and_swap(Lz, (Lx, Ly), axis=axis, transpose=transpose)
19+
assert all(c == _c for (c, _c) in zip(coords, _coords))
20+
_Lz, (_Lx, _Ly) = pop_axis_and_swap(_coords, axis=axis, transpose=transpose)
21+
assert Lz == _Lz
22+
assert Lx == _Lx
23+
assert Ly == _Ly
24+
25+
26+
def test_shape_swap_xy():
27+
p_orig = Point(random.random(), random.random())
28+
p_new = shape_swap_xy(p_orig)
29+
assert (p_new.coords[0][0], p_new.coords[0][1]) == (p_orig.coords[0][1], p_orig.coords[0][0])

tidy3d/components/data/unstructured/triangular.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
SpatialDataArray,
2424
)
2525
from tidy3d.components.types import ArrayLike, Ax, Axis, Bound
26+
from tidy3d.components.utils import pop_axis_and_swap
2627
from tidy3d.components.viz import add_ax_if_none, equal_aspect, plot_params_grid
2728
from tidy3d.constants import inf
2829
from tidy3d.exceptions import DataError
@@ -570,9 +571,10 @@ def does_cover(self, bounds: Bound) -> bool:
570571

571572
""" Plotting """
572573

573-
@property
574-
def _triangulation_obj(self) -> Triangulation:
574+
def _triangulation_obj(self, transpose: bool = False) -> Triangulation:
575575
"""Matplotlib triangular representation of the grid to use in plotting."""
576+
if transpose:
577+
return Triangulation(self.points[:, 1], self.points[:, 0], self.cells)
576578
return Triangulation(self.points[:, 0], self.points[:, 1], self.cells)
577579

578580
@equal_aspect
@@ -589,6 +591,7 @@ def plot(
589591
shading: Literal["gourand", "flat"] = "gouraud",
590592
cbar_kwargs: Optional[dict] = None,
591593
pcolor_kwargs: Optional[dict] = None,
594+
transpose: bool = False,
592595
) -> Ax:
593596
"""Plot the data field and/or the unstructured grid.
594597
@@ -616,6 +619,8 @@ def plot(
616619
Additional parameters passed to colorbar object.
617620
pcolor_kwargs: Dict = {}
618621
Additional parameters passed to ax.tripcolor()
622+
transpose : bool = False
623+
Swap horizontal and vertical axes. (This overrides the default ascending axis order)
619624
620625
Returns
621626
-------
@@ -639,7 +644,7 @@ def plot(
639644
f"{self._values_coords_dict} before plotting."
640645
)
641646
plot_obj = ax.tripcolor(
642-
self._triangulation_obj,
647+
self._triangulation_obj(transpose=transpose),
643648
self.values.data.ravel(),
644649
shading=shading,
645650
cmap=cmap,
@@ -657,14 +662,15 @@ def plot(
657662
# plot grid if requested
658663
if grid:
659664
ax.triplot(
660-
self._triangulation_obj,
665+
self._triangulation_obj(transpose=transpose),
661666
color=plot_params_grid.edgecolor,
662667
linewidth=plot_params_grid.linewidth,
663668
)
664669

665670
# set labels and titles
666-
ax_labels = ["x", "y", "z"]
667-
normal_axis_name = ax_labels.pop(self.normal_axis)
671+
normal_axis_name, ax_labels = pop_axis_and_swap(
672+
"xyz", self.normal_axis, transpose=transpose
673+
)
668674
ax.set_xlabel(ax_labels[0])
669675
ax.set_ylabel(ax_labels[1])
670676
ax.set_title(f"{normal_axis_name} = {self.normal_pos}")

0 commit comments

Comments
 (0)