Skip to content

Commit d64d088

Browse files
committed
Geometry.plot() and Structure().plot() accept the "transpose" argument which swaps the horizontal and vertical axes"
1 parent 19229c8 commit d64d088

File tree

12 files changed

+268
-58
lines changed

12 files changed

+268
-58
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- `Geometry.plot()` and `Structure().plot()` now accept the `transpose=True` argument which swaps the horizontal and vertical axes of the plot.
12+
1013
### Changed
1114
- By default, batch downloads will skip files that already exist locally. To force re-downloading and replace existing files, pass the `replace_existing=True` argument to `Batch.load()`, `Batch.download()`, or `BatchData.load()`.
1215
- The `BatchData.load_sim_data()` function now overwrites any previously downloaded simulation files (instead of skipping them).

tests/test_components/test_geometry.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,15 @@
8787
_, AX = plt.subplots()
8888

8989

90-
@pytest.mark.parametrize("component", GEO_TYPES)
91-
def test_plot(component):
92-
_ = component.plot(z=0, ax=AX)
90+
@pytest.mark.parametrize("component, transpose", zip(GEO_TYPES, [True, False]))
91+
def test_plot(component, transpose):
92+
_ = component.plot(z=0, ax=AX, transpose=transpose)
9393
plt.close()
9494

9595

96-
def test_plot_with_units():
97-
_ = BOX.plot(z=0, ax=AX, plot_length_units="nm")
96+
@pytest.mark.parametrize("transpose", [True, False])
97+
def test_plot_with_units(transpose):
98+
_ = BOX.plot(z=0, ax=AX, plot_length_units="nm", transpose=transpose)
9899
plt.close()
99100

100101

@@ -768,13 +769,12 @@ def test_geometry_touching_intersections_plane(x0):
768769

769770

770771
def test_pop_axis():
771-
b = td.Box(size=(1, 1, 1))
772772
for axis in range(3):
773773
coords = (1, 2, 3)
774-
Lz, (Lx, Ly) = b.pop_axis(coords, axis=axis)
775-
_coords = b.unpop_axis(Lz, (Lx, Ly), axis=axis)
774+
Lz, (Lx, Ly) = td.Box.pop_axis(coords, axis=axis)
775+
_coords = td.Box.unpop_axis(Lz, (Lx, Ly), axis=axis)
776776
assert all(c == _c for (c, _c) in zip(coords, _coords))
777-
_Lz, (_Lx, _Ly) = b.pop_axis(_coords, axis=axis)
777+
_Lz, (_Lx, _Ly) = td.Box.pop_axis(_coords, axis=axis)
778778
assert Lz == _Lz
779779
assert Lx == _Lx
780780
assert Ly == _Ly
@@ -939,7 +939,8 @@ def test_to_gds(geometry, tmp_path):
939939
assert len(cell.polygons) == 0
940940

941941

942-
def test_custom_surface_geometry(tmp_path):
942+
@pytest.mark.parametrize("transpose", [True, False])
943+
def test_custom_surface_geometry(transpose, tmp_path):
943944
# create tetrahedron STL
944945
vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
945946
faces = np.array([[1, 2, 3], [0, 3, 2], [0, 1, 3], [0, 2, 1]])
@@ -972,9 +973,13 @@ def test_custom_surface_geometry(tmp_path):
972973
assert np.isclose(geom.volume(), 1 / 6)
973974

974975
# test intersections
975-
assert shapely.equals(geom.intersections_plane(x=0), shapely.Polygon([[0, 0], [0, 1], [1, 0]]))
976976
assert shapely.equals(
977-
geom.intersections_plane(z=0.5), shapely.Polygon([[0, 0], [0, 0.5], [0.5, 0]])
977+
geom.intersections_plane(x=0),
978+
shapely.Polygon([[0, 0], [0, 1], [1, 0]]),
979+
)
980+
assert shapely.equals(
981+
geom.intersections_plane(z=0.5),
982+
shapely.Polygon([[0, 0], [0, 0.5], [0.5, 0]]),
978983
)
979984

980985
# test inside
@@ -983,7 +988,7 @@ def test_custom_surface_geometry(tmp_path):
983988

984989
# test plot
985990
_, ax = plt.subplots()
986-
_ = geom.plot(z=0.1, ax=ax)
991+
_ = geom.plot(z=0.1, ax=ax, transpose=transpose)
987992
plt.close()
988993

989994
# test inconsistent winding
@@ -1033,7 +1038,7 @@ def test_custom_surface_geometry(tmp_path):
10331038
boundary_spec=td.BoundarySpec.all_sides(td.PML()),
10341039
)
10351040
_, ax = plt.subplots()
1036-
_ = sim.plot(y=0, ax=ax)
1041+
_ = sim.plot(y=0, ax=ax, transpose=transpose)
10371042
plt.close()
10381043

10391044
# allow small triangles

tests/test_components/test_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Tests objects shared by multiple components."""
2+
3+
from __future__ import annotations
4+
5+
import random
6+
7+
import pytest
8+
from shapely.geometry import Point
9+
10+
from tidy3d.components.utils import pop_axis_and_swap, shape_swap_xy, unpop_axis_and_swap
11+
12+
13+
@pytest.mark.parametrize("transpose", [True, False])
14+
def test_pop_axis_and_swap(transpose):
15+
for axis in range(3):
16+
coords = (1, 2, 3)
17+
Lz, (Lx, Ly) = pop_axis_and_swap(coords, axis=axis, transpose=transpose)
18+
_coords = unpop_axis_and_swap(Lz, (Lx, Ly), axis=axis, transpose=transpose)
19+
assert all(c == _c for (c, _c) in zip(coords, _coords))
20+
_Lz, (_Lx, _Ly) = pop_axis_and_swap(_coords, axis=axis, transpose=transpose)
21+
assert Lz == _Lz
22+
assert Lx == _Lx
23+
assert Ly == _Ly
24+
25+
26+
def test_shape_swap_xy():
27+
p_orig = Point(random.random(), random.random())
28+
p_new = shape_swap_xy(p_orig)
29+
assert (p_new.coords[0][0], p_new.coords[0][1]) == (p_orig.coords[0][1], p_orig.coords[0][0])

tidy3d/components/data/unstructured/triangular.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
SpatialDataArray,
2424
)
2525
from tidy3d.components.types import ArrayLike, Ax, Axis, Bound
26+
from tidy3d.components.utils import pop_axis_and_swap
2627
from tidy3d.components.viz import add_ax_if_none, equal_aspect, plot_params_grid
2728
from tidy3d.constants import inf
2829
from tidy3d.exceptions import DataError
@@ -570,9 +571,10 @@ def does_cover(self, bounds: Bound) -> bool:
570571

571572
""" Plotting """
572573

573-
@property
574-
def _triangulation_obj(self) -> Triangulation:
574+
def _triangulation_obj(self, transpose: bool = False) -> Triangulation:
575575
"""Matplotlib triangular representation of the grid to use in plotting."""
576+
if transpose:
577+
return Triangulation(self.points[:, 1], self.points[:, 0], self.cells)
576578
return Triangulation(self.points[:, 0], self.points[:, 1], self.cells)
577579

578580
@equal_aspect
@@ -589,6 +591,7 @@ def plot(
589591
shading: Literal["gourand", "flat"] = "gouraud",
590592
cbar_kwargs: Optional[dict] = None,
591593
pcolor_kwargs: Optional[dict] = None,
594+
transpose: bool = False,
592595
) -> Ax:
593596
"""Plot the data field and/or the unstructured grid.
594597
@@ -616,6 +619,8 @@ def plot(
616619
Additional parameters passed to colorbar object.
617620
pcolor_kwargs: Dict = {}
618621
Additional parameters passed to ax.tripcolor()
622+
transpose : bool = False
623+
Swap horizontal and vertical axes. (This overrides the default ascending axis order)
619624
620625
Returns
621626
-------
@@ -639,7 +644,7 @@ def plot(
639644
f"{self._values_coords_dict} before plotting."
640645
)
641646
plot_obj = ax.tripcolor(
642-
self._triangulation_obj,
647+
self._triangulation_obj(transpose=transpose),
643648
self.values.data.ravel(),
644649
shading=shading,
645650
cmap=cmap,
@@ -657,14 +662,15 @@ def plot(
657662
# plot grid if requested
658663
if grid:
659664
ax.triplot(
660-
self._triangulation_obj,
665+
self._triangulation_obj(transpose=transpose),
661666
color=plot_params_grid.edgecolor,
662667
linewidth=plot_params_grid.linewidth,
663668
)
664669

665670
# set labels and titles
666-
ax_labels = ["x", "y", "z"]
667-
normal_axis_name = ax_labels.pop(self.normal_axis)
671+
normal_axis_name, ax_labels = pop_axis_and_swap(
672+
"xyz", self.normal_axis, transpose=transpose
673+
)
668674
ax.set_xlabel(ax_labels[0])
669675
ax.set_ylabel(ax_labels[1])
670676
ax.set_title(f"{normal_axis_name} = {self.normal_pos}")

0 commit comments

Comments
 (0)