Skip to content

Commit ec85831

Browse files
committed
feat: make a FieldPlotter for CalculusPatch
1 parent e487da6 commit ec85831

File tree

2 files changed

+39
-22
lines changed

2 files changed

+39
-22
lines changed

sumpy/point_calculus.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
from pytools import memoize_method, obj_array
3131

32+
from sumpy.visualization import FieldPlotter
33+
3234

3335
if TYPE_CHECKING:
3436
from collections.abc import Callable, Sequence
@@ -345,6 +347,9 @@ def norm(self,
345347
else:
346348
raise ValueError("unsupported norm")
347349

350+
def make_field_plotter(self) -> FieldPlotter:
351+
return FieldPlotter(self.center, self.h, points=self._points_shaped)
352+
348353
def plot_nodes(self) -> None:
349354
if self.dim != 2:
350355
raise ValueError(f"cannot plot {self.dim}d fields")
@@ -357,6 +362,11 @@ def plot_nodes(self) -> None:
357362
"o")
358363

359364
def plot(self, f: Array1D[np.floating[Any]]) -> None:
365+
from warnings import warn
366+
warn(f"Calling '{type(self).__name__}.plot' is deprecated. Use "
367+
f"'{type(self).__name__}.make_field_plotter' instead, which also "
368+
"supports 3d fields.", DeprecationWarning, stacklevel=2)
369+
360370
if self.dim != 2:
361371
raise ValueError(f"cannot plot {self.dim}d fields")
362372

sumpy/visualization.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -96,52 +96,59 @@ def make_field_plotter_from_bbox(
9696

9797
class FieldPlotter:
9898
"""
99+
.. autoattribute:: dimensions
100+
.. autoattribute:: npoints
101+
.. autoattribute:: points
102+
99103
.. automethod:: set_matplotlib_limits
100104
.. automethod:: show_scalar_in_matplotlib
101105
.. automethod:: show_scalar_in_mayavi
102106
.. automethod:: write_vtk_file
103107
"""
104108

105109
dimensions: int
110+
npoints: int
111+
points: onp.Array2D[np.floating[Any]]
112+
106113
a: onp.Array1D[np.floating[Any]]
107114
b: onp.Array1D[np.floating[Any]]
108-
109-
nd_points: onp.Array2D[np.floating[Any]]
110-
points: onp.Array2D[np.floating[Any]]
111-
npoints: int
115+
nd_points: onp.ArrayND[np.floating[Any]]
112116

113117
def __init__(self,
114118
center: onp.ToArray1D[np.floating[Any]],
115119
extent: float | onp.Array1D[np.floating[Any]] = 1,
116-
npoints: int | tuple[int, ...] = 1000) -> None:
120+
npoints: int | tuple[int, ...] = 1000,
121+
points: onp.ArrayND[np.floating[Any]] | None = None) -> None:
117122
center = np.asarray(center)
118123
dim, = cast("tuple[int]", center.shape)
119124

120125
self.dimensions = dim
121126
self.a = a = center - 0.5 * extent
122127
self.b = b = center + 0.5 * extent
123128

124-
from numbers import Number
125-
if isinstance(npoints, (int, Number)):
126-
npoints = dim*(npoints,)
129+
if points is None:
130+
from numbers import Number
131+
if isinstance(npoints, (int, Number)):
132+
npoints = dim*(npoints,)
133+
else:
134+
if len(npoints) != dim:
135+
raise ValueError("length of npoints must match dimension")
136+
137+
for i in range(dim):
138+
if npoints[i] == 1:
139+
a[i] = center[i]
140+
141+
mgrid_index = tuple(
142+
slice(a[i], b[i], 1j*npoints[i])
143+
for i in range(dim))
144+
mgrid = np.mgrid[mgrid_index]
127145
else:
128-
if len(npoints) != dim:
129-
raise ValueError("length of npoints must match dimension")
130-
131-
for i in range(dim):
132-
if npoints[i] == 1:
133-
a[i] = center[i]
134-
135-
mgrid_index = tuple(
136-
slice(a[i], b[i], 1j*npoints[i])
137-
for i in range(dim))
138-
mgrid = np.mgrid[mgrid_index]
146+
mgrid = points
139147

140148
# (axis, point x idx, point y idx, ...)
141149
self.nd_points = mgrid
142-
143-
self.points = self.nd_points.reshape(dim, -1).copy()
144-
self.npoints = np.prod(npoints)
150+
self.points = mgrid.reshape(dim, -1).copy()
151+
self.npoints = mgrid.size
145152

146153
def _get_nontrivial_dims(self) -> onp.Array1D[np.bool_]:
147154
return np.array(self.nd_points.shape[1:]) != 1

0 commit comments

Comments
 (0)