Skip to content

Commit 27c2883

Browse files
cvanelterenCopilotbeckermr
authored
Add curved_quiver — Curved Vector Field Arrows for 2D Plots (#361)
* Add curved curved_quiver * add unittests * update docstrings * black formatting * rm dup import * Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update ultraplot/tests/test_plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * mv import up * refactor and move down * black formatting * update tests with new api * add one test as image comp * Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update ultraplot/tests/test_plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * black formatting * inline the termination to make it slightly more compact * mv curved quiver plot to 'plot_types' and update tests' * mv parameters to rcsetup * add type hinting * rm dup docstring * rm unused imports * Update ultraplot/axes/plot_types/curved_quiver.py Co-authored-by: Matthew R. Becker <beckermr@users.noreply.github.com> * Apply suggestion from @beckermr Co-authored-by: Matthew R. Becker <beckermr@users.noreply.github.com> * Apply suggestion from @beckermr Co-authored-by: Matthew R. Becker <beckermr@users.noreply.github.com> * Apply suggestion from @beckermr Co-authored-by: Matthew R. Becker <beckermr@users.noreply.github.com> * rename private classes * more renaming * more renaming --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Matthew R. Becker <beckermr@users.noreply.github.com>
1 parent d514aa1 commit 27c2883

File tree

5 files changed

+853
-15
lines changed

5 files changed

+853
-15
lines changed

ultraplot/axes/plot.py

Lines changed: 269 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from typing import Any, Union, Iterable, Optional
1414

15-
from typing import Any, Union
1615
from collections.abc import Callable
1716
from collections.abc import Iterable
1817

@@ -33,6 +32,7 @@
3332
import matplotlib as mpl
3433
from packaging import version
3534
import numpy as np
35+
from typing import Optional, Union, Any
3636
import numpy.ma as ma
3737

3838
from .. import colors as pcolors
@@ -170,7 +170,45 @@
170170
docstring._snippet_manager["plot.args_1d_shared"] = _args_1d_shared_docstring
171171
docstring._snippet_manager["plot.args_2d_shared"] = _args_2d_shared_docstring
172172

173+
_curved_quiver_docstring = """
174+
Draws curved vector field arrows (streamlines with arrows) for 2D vector fields.
173175
176+
Parameters
177+
----------
178+
x, y : 1D or 2D arrays
179+
Grid coordinates.
180+
u, v : 2D arrays
181+
Vector components.
182+
color : color or 2D array, optional
183+
Streamline color.
184+
density : float or (float, float), optional
185+
Controls the closeness of streamlines.
186+
grains : int or (int, int), optional
187+
Number of seed points in x and y.
188+
linewidth : float or 2D array, optional
189+
Width of streamlines.
190+
cmap, norm : optional
191+
Colormap and normalization for array colors.
192+
arrowsize : float, optional
193+
Arrow size scaling.
194+
arrowstyle : str, optional
195+
Arrow style specification.
196+
transform : optional
197+
Matplotlib transform.
198+
zorder : float, optional
199+
Z-order for lines/arrows.
200+
start_points : (N, 2) array, optional
201+
Starting points for streamlines.
202+
203+
Returns
204+
-------
205+
CurvedQuiverSet
206+
Container with attributes:
207+
- lines: LineCollection of streamlines
208+
- arrows: PatchCollection of arrows
209+
"""
210+
211+
docstring._snippet_manager["plot.curved_quiver"] = _curved_quiver_docstring
174212
# Auto colorbar and legend docstring
175213
_guide_docstring = """
176214
colorbar : bool, int, or str, optional
@@ -1499,22 +1537,237 @@ class PlotAxes(base.Axes):
14991537
Implements all plotting overrides.
15001538
"""
15011539

1502-
def __init__(self, *args, **kwargs):
1540+
@docstring._snippet_manager
1541+
def curved_quiver(
1542+
self,
1543+
x: np.ndarray,
1544+
y: np.ndarray,
1545+
u: np.ndarray,
1546+
v: np.ndarray,
1547+
linewidth: Optional[float] = None,
1548+
color: Optional[Union[str, Any]] = None,
1549+
cmap: Optional[Any] = None,
1550+
norm: Optional[Any] = None,
1551+
arrowsize: Optional[float] = None,
1552+
arrowstyle: Optional[str] = None,
1553+
transform: Optional[Any] = None,
1554+
zorder: Optional[int] = None,
1555+
start_points: Optional[np.ndarray] = None,
1556+
scale: Optional[float] = None,
1557+
grains: Optional[int] = None,
1558+
density: Optional[int] = None,
1559+
arrow_at_end: Optional[bool] = None,
1560+
):
15031561
"""
1504-
Parameters
1505-
----------
1506-
*args, **kwargs
1507-
Passed to `ultraplot.axes.Axes`.
1562+
%(plot.curved_quiver)s
1563+
1564+
Notes
1565+
-----
1566+
The implementation of this function is based on the `dfm_tools` repository.
1567+
Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py
1568+
"""
1569+
from .plot_types.curved_quiver import CurvedQuiverSolver, CurvedQuiverSet
1570+
1571+
# Parse inputs
1572+
arrowsize = _not_none(arrowsize, rc["curved_quiver.arrowsize"])
1573+
arrowstyle = _not_none(arrowstyle, rc["curved_quiver.arrowstyle"])
1574+
zorder = _not_none(zorder, mlines.Line2D.zorder)
1575+
transform = _not_none(transform, self.transData)
1576+
color = _not_none(color, self._get_lines.get_next_color())
1577+
linewidth = _not_none(linewidth, rc["lines.linewidth"])
1578+
scale = _not_none(scale, rc["curved_quiver.scale"])
1579+
grains = _not_none(grains, rc["curved_quiver.grains"])
1580+
density = _not_none(density, rc["curved_quiver.density"])
1581+
arrows_at_end = _not_none(arrow_at_end, rc["curved_quiver.arrows_at_end"])
1582+
1583+
solver = CurvedQuiverSolver(x, y, density)
1584+
if zorder is None:
1585+
zorder = mlines.Line2D.zorder
1586+
1587+
line_kw = {}
1588+
arrow_kw = dict(arrowstyle=arrowstyle, mutation_scale=10 * arrowsize)
1589+
1590+
use_multicolor_lines = isinstance(color, np.ndarray)
1591+
if use_multicolor_lines:
1592+
if color.shape != solver.grid.shape:
1593+
raise ValueError(
1594+
"If 'color' is given, must have the shape of 'Grid(x,y)'"
1595+
)
1596+
line_colors = []
1597+
color = np.ma.masked_invalid(color)
1598+
else:
1599+
line_kw["color"] = color
1600+
arrow_kw["color"] = color
15081601

1509-
See also
1510-
--------
1511-
matplotlib.axes.Axes
1512-
ultraplot.axes.Axes
1513-
ultraplot.axes.CartesianAxes
1514-
ultraplot.axes.PolarAxes
1515-
ultraplot.axes.GeoAxes
1516-
"""
1517-
super().__init__(*args, **kwargs)
1602+
if isinstance(linewidth, np.ndarray):
1603+
if linewidth.shape != solver.grid.shape:
1604+
raise ValueError(
1605+
"If 'linewidth' is given, must have the shape of 'Grid(x,y)'"
1606+
)
1607+
line_kw["linewidth"] = []
1608+
else:
1609+
line_kw["linewidth"] = linewidth
1610+
arrow_kw["linewidth"] = linewidth
1611+
1612+
line_kw["zorder"] = zorder
1613+
arrow_kw["zorder"] = zorder
1614+
1615+
## Sanity checks.
1616+
if u.shape != solver.grid.shape or v.shape != solver.grid.shape:
1617+
raise ValueError("'u' and 'v' must be of shape 'Grid(x,y)'")
1618+
1619+
u = np.ma.masked_invalid(u)
1620+
v = np.ma.masked_invalid(v)
1621+
magnitude = np.sqrt(u**2 + v**2)
1622+
magnitude /= np.max(magnitude)
1623+
1624+
resolution = scale / grains
1625+
minlength = 0.9 * resolution
1626+
1627+
integrate = solver.get_integrator(u, v, minlength, resolution, magnitude)
1628+
trajectories = []
1629+
edges = []
1630+
1631+
if start_points is None:
1632+
start_points = solver.gen_starting_points(x, y, grains)
1633+
1634+
sp2 = np.asanyarray(start_points, dtype=float).copy()
1635+
1636+
# Check if start_points are outside the data boundaries
1637+
for xs, ys in sp2:
1638+
if not (
1639+
solver.grid.x_origin <= xs <= solver.grid.x_origin + solver.grid.width
1640+
and solver.grid.y_origin
1641+
<= ys
1642+
<= solver.grid.y_origin + solver.grid.height
1643+
):
1644+
raise ValueError(
1645+
"Starting point ({}, {}) outside of data "
1646+
"boundaries".format(xs, ys)
1647+
)
1648+
1649+
if use_multicolor_lines:
1650+
if norm is None:
1651+
norm = mcolors.Normalize(color.min(), color.max())
1652+
if cmap is None:
1653+
cmap = constructor.Colormap(rc["image.cmap"])
1654+
else:
1655+
cmap = mcm.get_cmap(cmap)
1656+
1657+
# Convert start_points from data to array coords
1658+
# Shift the seed points from the bottom left of the data so that
1659+
# data2grid works properly.
1660+
sp2[:, 0] -= solver.grid.x_origin
1661+
sp2[:, 1] -= solver.grid.y_origin
1662+
1663+
for xs, ys in sp2:
1664+
xg, yg = solver.domain_map.data2grid(xs, ys)
1665+
t = integrate(xg, yg)
1666+
if t is not None:
1667+
trajectories.append(t[0])
1668+
edges.append(t[1])
1669+
streamlines = []
1670+
arrows = []
1671+
for t, edge in zip(trajectories, edges):
1672+
tgx = np.array(t[0])
1673+
tgy = np.array(t[1])
1674+
1675+
# Rescale from grid-coordinates to data-coordinates.
1676+
tx, ty = solver.domain_map.grid2data(*np.array(t))
1677+
tx += solver.grid.x_origin
1678+
ty += solver.grid.y_origin
1679+
1680+
points = np.transpose([tx, ty]).reshape(-1, 1, 2)
1681+
streamlines.extend(np.hstack([points[:-1], points[1:]]))
1682+
1683+
if len(tx) < 2:
1684+
continue
1685+
1686+
# Add arrows
1687+
s = np.cumsum(np.sqrt(np.diff(tx) ** 2 + np.diff(ty) ** 2))
1688+
if arrow_at_end:
1689+
if len(tx) < 2:
1690+
continue
1691+
1692+
arrow_tail = (tx[-1], ty[-1])
1693+
1694+
# Extrapolate to find arrow head
1695+
xg, yg = solver.domain_map.data2grid(
1696+
tx[-1] - solver.grid.x_origin, ty[-1] - solver.grid.y_origin
1697+
)
1698+
1699+
ui = solver.interpgrid(u, xg, yg)
1700+
vi = solver.interpgrid(v, xg, yg)
1701+
1702+
norm_v = np.sqrt(ui**2 + vi**2)
1703+
if norm_v > 0:
1704+
ui /= norm_v
1705+
vi /= norm_v
1706+
1707+
if len(s) > 0:
1708+
# use average segment length
1709+
arrow_length = arrowsize * (s[-1] / len(s))
1710+
else:
1711+
# fallback for very short streamlines
1712+
arrow_length = (
1713+
arrowsize * 0.1 * np.mean([solver.grid.dx, solver.grid.dy])
1714+
)
1715+
1716+
arrow_head = (tx[-1] + ui * arrow_length, ty[-1] + vi * arrow_length)
1717+
n = len(s) - 1 if len(s) > 0 else 0
1718+
else:
1719+
n = np.searchsorted(s, s[-1] / 2.0)
1720+
arrow_tail = (tx[n], ty[n])
1721+
arrow_head = (np.mean(tx[n : n + 2]), np.mean(ty[n : n + 2]))
1722+
1723+
if isinstance(linewidth, np.ndarray):
1724+
line_widths = solver.interpgrid(linewidth, tgx, tgy)[:-1]
1725+
line_kw["linewidth"].extend(line_widths)
1726+
arrow_kw["linewidth"] = line_widths[n]
1727+
1728+
if use_multicolor_lines:
1729+
color_values = solver.interpgrid(color, tgx, tgy)[:-1]
1730+
line_colors.append(color_values)
1731+
arrow_kw["color"] = cmap(norm(color_values[n]))
1732+
1733+
if not edge:
1734+
p = mpatches.FancyArrowPatch(
1735+
arrow_tail, arrow_head, transform=transform, **arrow_kw
1736+
)
1737+
else:
1738+
continue
1739+
1740+
ds = np.sqrt(
1741+
(arrow_tail[0] - arrow_head[0]) ** 2
1742+
+ (arrow_tail[1] - arrow_head[1]) ** 2
1743+
)
1744+
if ds < 1e-15:
1745+
continue # remove vanishingly short arrows that cause Patch to fail
1746+
1747+
self.add_patch(p)
1748+
arrows.append(p)
1749+
1750+
lc = mcollections.LineCollection(streamlines, transform=transform, **line_kw)
1751+
lc.sticky_edges.x[:] = [
1752+
solver.grid.x_origin,
1753+
solver.grid.x_origin + solver.grid.width,
1754+
]
1755+
lc.sticky_edges.y[:] = [
1756+
solver.grid.y_origin,
1757+
solver.grid.y_origin + solver.grid.height,
1758+
]
1759+
1760+
if use_multicolor_lines:
1761+
lc.set_array(np.ma.hstack(line_colors))
1762+
lc.set_cmap(cmap)
1763+
lc.set_norm(norm)
1764+
1765+
self.add_collection(lc)
1766+
self.autoscale_view()
1767+
1768+
ac = mcollections.PatchCollection(arrows)
1769+
stream_container = CurvedQuiverSet(lc, ac)
1770+
return stream_container
15181771

15191772
def _call_native(self, name, *args, **kwargs):
15201773
"""
@@ -5359,6 +5612,7 @@ def tripcolor(self, *args, **kwargs):
53595612

53605613
# Update kwargs and handle cmap
53615614
kw.update(_pop_props(kw, "collection"))
5615+
53625616
center_levels = kw.pop("center_levels", None)
53635617
kw = self._parse_cmap(
53645618
triangulation.x, triangulation.y, z, center_levels=center_levels, **kw

ultraplot/axes/plot_types/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)