Skip to content

Commit 33ddf11

Browse files
committed
Add function to obtain cell values in UnstructuredGridDataset
1 parent 95b12fd commit 33ddf11

File tree

4 files changed

+103
-0
lines changed

4 files changed

+103
-0
lines changed

tests/test_data/test_datasets.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,3 +667,75 @@ def test_triangular_dataset_uniform():
667667

668668
tri_grid = tri_grid.updated_copy(values=tri_grid_values)
669669
assert not tri_grid.is_uniform
670+
671+
672+
def test_cell_values():
673+
"""Test whether the cell values are correctly calculated"""
674+
import tidy3d as td
675+
676+
# start with a triangle grid
677+
tri_grid_points = td.PointDataArray(
678+
[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
679+
dims=("index", "axis"),
680+
)
681+
682+
tri_grid_cells = td.CellDataArray(
683+
[[0, 1, 2], [1, 2, 3]],
684+
dims=("cell_index", "vertex_index"),
685+
)
686+
687+
tri_grid_values = td.IndexedVoltageDataArray(
688+
[[0.0, 0.0], [0, 0], [3, -3], [3, -3]],
689+
coords=dict(index=np.arange(4), voltage=[-1, 1]),
690+
name="test",
691+
)
692+
693+
tri_grid = td.TriangularGridDataset(
694+
normal_axis=1,
695+
normal_pos=0,
696+
points=tri_grid_points,
697+
cells=tri_grid_cells,
698+
values=tri_grid_values,
699+
)
700+
701+
cell_values = tri_grid.get_cell_values(voltage=-1)
702+
cell_vols = tri_grid.get_cell_volumes()
703+
assert np.dot(cell_values, cell_vols) == 1.5
704+
705+
# Now repeat for a tet mesh
706+
tet_grid_points = td.PointDataArray(
707+
[
708+
[0.0, 0.0, 0.0],
709+
[1.0, 0.0, 0.0],
710+
[0.0, 1.0, 0.0],
711+
[1.0, 1.0, 0.0],
712+
[0.0, 0.0, 1.0],
713+
[1.0, 0.0, 1.0],
714+
[0.0, 1.0, 1.0],
715+
[1.0, 1.0, 1.0],
716+
],
717+
dims=("index", "axis"),
718+
)
719+
720+
tet_grid_cells = td.CellDataArray(
721+
[[0, 1, 3, 7], [0, 2, 7, 3], [0, 2, 6, 7], [0, 4, 7, 6], [0, 4, 5, 7], [0, 1, 7, 5]],
722+
dims=("cell_index", "vertex_index"),
723+
)
724+
725+
tet_grid_values = td.IndexedDataArray(
726+
[0.0, 0.0, 0.0, 0.0, 3.0, 3.0, 3.0, 3.0], coords=dict(index=np.arange(8)), name="test_tet"
727+
)
728+
729+
tet_grid = td.TetrahedralGridDataset(
730+
points=tet_grid_points,
731+
cells=tet_grid_cells,
732+
values=tet_grid_values,
733+
)
734+
735+
# this will fail since we now have a single field (voltage isn't a coordinate)
736+
with pytest.raises(KeyError):
737+
_ = tet_grid.get_cell_values(voltage=1)
738+
739+
cell_values = tet_grid.get_cell_values()
740+
cell_vols = tet_grid.get_cell_volumes()
741+
assert np.dot(cell_values, cell_vols) == 1.5

tidy3d/components/data/unstructured/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,20 @@ def _get_values_from_vtk(
641641

642642
return values
643643

644+
def get_cell_values(self, **kwargs):
645+
"""This function returns the cell values for the fields stored in the UnstructuredGridDataset.
646+
If multiple fields are stored per point, like in an IndexedVoltageDataArray, cell values
647+
will be provided for each of the fields unless a selection argument is provided, e.g., voltage=0.2
648+
"""
649+
650+
values = self.values.sel(**kwargs)
651+
652+
return values[self.cells].mean(dim="vertex_index").values
653+
654+
@abstractmethod
655+
def get_cell_volumes(self):
656+
"""Get the volumes associated to each cell."""
657+
644658
""" Grid operations """
645659

646660
@requires_vtk

tidy3d/components/data/unstructured/tetrahedral.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,12 @@ def sel(
347347
return self_after_non_spatial_sel.interp(x=x, y=y, z=z)
348348

349349
return self_after_non_spatial_sel
350+
351+
def get_cell_volumes(self):
352+
"""Get the volumes associated to each cell in the grid"""
353+
v0 = self.points[self.cells.sel(vertex_index=0)]
354+
e01 = self.points[self.cells.sel(vertex_index=1)] - v0
355+
e02 = self.points[self.cells.sel(vertex_index=2)] - v0
356+
e03 = self.points[self.cells.sel(vertex_index=3)] - v0
357+
358+
return np.abs(np.sum(np.cross(e01, e02) * e03, axis=1)) / 6

tidy3d/components/data/unstructured/triangular.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,11 @@ def plot(
657657
ax.set_ylabel(ax_labels[1])
658658
ax.set_title(f"{normal_axis_name} = {self.normal_pos}")
659659
return ax
660+
661+
def get_cell_volumes(self):
662+
"""Get areas associated to each cell of the grid."""
663+
v0 = self.points[self.cells.sel(vertex_index=0)]
664+
e01 = self.points[self.cells.sel(vertex_index=1)] - v0
665+
e02 = self.points[self.cells.sel(vertex_index=2)] - v0
666+
667+
return 0.5 * np.abs(np.cross(e01, e02))

0 commit comments

Comments
 (0)