|
12 | 12 |
|
13 | 13 | from typing import Any, Union, Iterable, Optional |
14 | 14 |
|
15 | | -from typing import Any, Union |
16 | 15 | from collections.abc import Callable |
17 | 16 | from collections.abc import Iterable |
18 | 17 |
|
|
33 | 32 | import matplotlib as mpl |
34 | 33 | from packaging import version |
35 | 34 | import numpy as np |
| 35 | +from typing import Optional, Union, Any |
36 | 36 | import numpy.ma as ma |
37 | 37 |
|
38 | 38 | from .. import colors as pcolors |
|
170 | 170 | docstring._snippet_manager["plot.args_1d_shared"] = _args_1d_shared_docstring |
171 | 171 | docstring._snippet_manager["plot.args_2d_shared"] = _args_2d_shared_docstring |
172 | 172 |
|
| 173 | +_curved_quiver_docstring = """ |
| 174 | +Draws curved vector field arrows (streamlines with arrows) for 2D vector fields. |
173 | 175 |
|
| 176 | +Parameters |
| 177 | +---------- |
| 178 | +x, y : 1D or 2D arrays |
| 179 | + Grid coordinates. |
| 180 | +u, v : 2D arrays |
| 181 | + Vector components. |
| 182 | +color : color or 2D array, optional |
| 183 | + Streamline color. |
| 184 | +density : float or (float, float), optional |
| 185 | + Controls the closeness of streamlines. |
| 186 | +grains : int or (int, int), optional |
| 187 | + Number of seed points in x and y. |
| 188 | +linewidth : float or 2D array, optional |
| 189 | + Width of streamlines. |
| 190 | +cmap, norm : optional |
| 191 | + Colormap and normalization for array colors. |
| 192 | +arrowsize : float, optional |
| 193 | + Arrow size scaling. |
| 194 | +arrowstyle : str, optional |
| 195 | + Arrow style specification. |
| 196 | +transform : optional |
| 197 | + Matplotlib transform. |
| 198 | +zorder : float, optional |
| 199 | + Z-order for lines/arrows. |
| 200 | +start_points : (N, 2) array, optional |
| 201 | + Starting points for streamlines. |
| 202 | +
|
| 203 | +Returns |
| 204 | +------- |
| 205 | +CurvedQuiverSet |
| 206 | + Container with attributes: |
| 207 | + - lines: LineCollection of streamlines |
| 208 | + - arrows: PatchCollection of arrows |
| 209 | +""" |
| 210 | + |
| 211 | +docstring._snippet_manager["plot.curved_quiver"] = _curved_quiver_docstring |
174 | 212 | # Auto colorbar and legend docstring |
175 | 213 | _guide_docstring = """ |
176 | 214 | colorbar : bool, int, or str, optional |
@@ -1499,22 +1537,237 @@ class PlotAxes(base.Axes): |
1499 | 1537 | Implements all plotting overrides. |
1500 | 1538 | """ |
1501 | 1539 |
|
1502 | | - def __init__(self, *args, **kwargs): |
| 1540 | + @docstring._snippet_manager |
| 1541 | + def curved_quiver( |
| 1542 | + self, |
| 1543 | + x: np.ndarray, |
| 1544 | + y: np.ndarray, |
| 1545 | + u: np.ndarray, |
| 1546 | + v: np.ndarray, |
| 1547 | + linewidth: Optional[float] = None, |
| 1548 | + color: Optional[Union[str, Any]] = None, |
| 1549 | + cmap: Optional[Any] = None, |
| 1550 | + norm: Optional[Any] = None, |
| 1551 | + arrowsize: Optional[float] = None, |
| 1552 | + arrowstyle: Optional[str] = None, |
| 1553 | + transform: Optional[Any] = None, |
| 1554 | + zorder: Optional[int] = None, |
| 1555 | + start_points: Optional[np.ndarray] = None, |
| 1556 | + scale: Optional[float] = None, |
| 1557 | + grains: Optional[int] = None, |
| 1558 | + density: Optional[int] = None, |
| 1559 | + arrow_at_end: Optional[bool] = None, |
| 1560 | + ): |
1503 | 1561 | """ |
1504 | | - Parameters |
1505 | | - ---------- |
1506 | | - *args, **kwargs |
1507 | | - Passed to `ultraplot.axes.Axes`. |
| 1562 | + %(plot.curved_quiver)s |
| 1563 | +
|
| 1564 | + Notes |
| 1565 | + ----- |
| 1566 | + The implementation of this function is based on the `dfm_tools` repository. |
| 1567 | + Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py |
| 1568 | + """ |
| 1569 | + from .plot_types.curved_quiver import CurvedQuiverSolver, CurvedQuiverSet |
| 1570 | + |
| 1571 | + # Parse inputs |
| 1572 | + arrowsize = _not_none(arrowsize, rc["curved_quiver.arrowsize"]) |
| 1573 | + arrowstyle = _not_none(arrowstyle, rc["curved_quiver.arrowstyle"]) |
| 1574 | + zorder = _not_none(zorder, mlines.Line2D.zorder) |
| 1575 | + transform = _not_none(transform, self.transData) |
| 1576 | + color = _not_none(color, self._get_lines.get_next_color()) |
| 1577 | + linewidth = _not_none(linewidth, rc["lines.linewidth"]) |
| 1578 | + scale = _not_none(scale, rc["curved_quiver.scale"]) |
| 1579 | + grains = _not_none(grains, rc["curved_quiver.grains"]) |
| 1580 | + density = _not_none(density, rc["curved_quiver.density"]) |
| 1581 | + arrows_at_end = _not_none(arrow_at_end, rc["curved_quiver.arrows_at_end"]) |
| 1582 | + |
| 1583 | + solver = CurvedQuiverSolver(x, y, density) |
| 1584 | + if zorder is None: |
| 1585 | + zorder = mlines.Line2D.zorder |
| 1586 | + |
| 1587 | + line_kw = {} |
| 1588 | + arrow_kw = dict(arrowstyle=arrowstyle, mutation_scale=10 * arrowsize) |
| 1589 | + |
| 1590 | + use_multicolor_lines = isinstance(color, np.ndarray) |
| 1591 | + if use_multicolor_lines: |
| 1592 | + if color.shape != solver.grid.shape: |
| 1593 | + raise ValueError( |
| 1594 | + "If 'color' is given, must have the shape of 'Grid(x,y)'" |
| 1595 | + ) |
| 1596 | + line_colors = [] |
| 1597 | + color = np.ma.masked_invalid(color) |
| 1598 | + else: |
| 1599 | + line_kw["color"] = color |
| 1600 | + arrow_kw["color"] = color |
1508 | 1601 |
|
1509 | | - See also |
1510 | | - -------- |
1511 | | - matplotlib.axes.Axes |
1512 | | - ultraplot.axes.Axes |
1513 | | - ultraplot.axes.CartesianAxes |
1514 | | - ultraplot.axes.PolarAxes |
1515 | | - ultraplot.axes.GeoAxes |
1516 | | - """ |
1517 | | - super().__init__(*args, **kwargs) |
| 1602 | + if isinstance(linewidth, np.ndarray): |
| 1603 | + if linewidth.shape != solver.grid.shape: |
| 1604 | + raise ValueError( |
| 1605 | + "If 'linewidth' is given, must have the shape of 'Grid(x,y)'" |
| 1606 | + ) |
| 1607 | + line_kw["linewidth"] = [] |
| 1608 | + else: |
| 1609 | + line_kw["linewidth"] = linewidth |
| 1610 | + arrow_kw["linewidth"] = linewidth |
| 1611 | + |
| 1612 | + line_kw["zorder"] = zorder |
| 1613 | + arrow_kw["zorder"] = zorder |
| 1614 | + |
| 1615 | + ## Sanity checks. |
| 1616 | + if u.shape != solver.grid.shape or v.shape != solver.grid.shape: |
| 1617 | + raise ValueError("'u' and 'v' must be of shape 'Grid(x,y)'") |
| 1618 | + |
| 1619 | + u = np.ma.masked_invalid(u) |
| 1620 | + v = np.ma.masked_invalid(v) |
| 1621 | + magnitude = np.sqrt(u**2 + v**2) |
| 1622 | + magnitude /= np.max(magnitude) |
| 1623 | + |
| 1624 | + resolution = scale / grains |
| 1625 | + minlength = 0.9 * resolution |
| 1626 | + |
| 1627 | + integrate = solver.get_integrator(u, v, minlength, resolution, magnitude) |
| 1628 | + trajectories = [] |
| 1629 | + edges = [] |
| 1630 | + |
| 1631 | + if start_points is None: |
| 1632 | + start_points = solver.gen_starting_points(x, y, grains) |
| 1633 | + |
| 1634 | + sp2 = np.asanyarray(start_points, dtype=float).copy() |
| 1635 | + |
| 1636 | + # Check if start_points are outside the data boundaries |
| 1637 | + for xs, ys in sp2: |
| 1638 | + if not ( |
| 1639 | + solver.grid.x_origin <= xs <= solver.grid.x_origin + solver.grid.width |
| 1640 | + and solver.grid.y_origin |
| 1641 | + <= ys |
| 1642 | + <= solver.grid.y_origin + solver.grid.height |
| 1643 | + ): |
| 1644 | + raise ValueError( |
| 1645 | + "Starting point ({}, {}) outside of data " |
| 1646 | + "boundaries".format(xs, ys) |
| 1647 | + ) |
| 1648 | + |
| 1649 | + if use_multicolor_lines: |
| 1650 | + if norm is None: |
| 1651 | + norm = mcolors.Normalize(color.min(), color.max()) |
| 1652 | + if cmap is None: |
| 1653 | + cmap = constructor.Colormap(rc["image.cmap"]) |
| 1654 | + else: |
| 1655 | + cmap = mcm.get_cmap(cmap) |
| 1656 | + |
| 1657 | + # Convert start_points from data to array coords |
| 1658 | + # Shift the seed points from the bottom left of the data so that |
| 1659 | + # data2grid works properly. |
| 1660 | + sp2[:, 0] -= solver.grid.x_origin |
| 1661 | + sp2[:, 1] -= solver.grid.y_origin |
| 1662 | + |
| 1663 | + for xs, ys in sp2: |
| 1664 | + xg, yg = solver.domain_map.data2grid(xs, ys) |
| 1665 | + t = integrate(xg, yg) |
| 1666 | + if t is not None: |
| 1667 | + trajectories.append(t[0]) |
| 1668 | + edges.append(t[1]) |
| 1669 | + streamlines = [] |
| 1670 | + arrows = [] |
| 1671 | + for t, edge in zip(trajectories, edges): |
| 1672 | + tgx = np.array(t[0]) |
| 1673 | + tgy = np.array(t[1]) |
| 1674 | + |
| 1675 | + # Rescale from grid-coordinates to data-coordinates. |
| 1676 | + tx, ty = solver.domain_map.grid2data(*np.array(t)) |
| 1677 | + tx += solver.grid.x_origin |
| 1678 | + ty += solver.grid.y_origin |
| 1679 | + |
| 1680 | + points = np.transpose([tx, ty]).reshape(-1, 1, 2) |
| 1681 | + streamlines.extend(np.hstack([points[:-1], points[1:]])) |
| 1682 | + |
| 1683 | + if len(tx) < 2: |
| 1684 | + continue |
| 1685 | + |
| 1686 | + # Add arrows |
| 1687 | + s = np.cumsum(np.sqrt(np.diff(tx) ** 2 + np.diff(ty) ** 2)) |
| 1688 | + if arrow_at_end: |
| 1689 | + if len(tx) < 2: |
| 1690 | + continue |
| 1691 | + |
| 1692 | + arrow_tail = (tx[-1], ty[-1]) |
| 1693 | + |
| 1694 | + # Extrapolate to find arrow head |
| 1695 | + xg, yg = solver.domain_map.data2grid( |
| 1696 | + tx[-1] - solver.grid.x_origin, ty[-1] - solver.grid.y_origin |
| 1697 | + ) |
| 1698 | + |
| 1699 | + ui = solver.interpgrid(u, xg, yg) |
| 1700 | + vi = solver.interpgrid(v, xg, yg) |
| 1701 | + |
| 1702 | + norm_v = np.sqrt(ui**2 + vi**2) |
| 1703 | + if norm_v > 0: |
| 1704 | + ui /= norm_v |
| 1705 | + vi /= norm_v |
| 1706 | + |
| 1707 | + if len(s) > 0: |
| 1708 | + # use average segment length |
| 1709 | + arrow_length = arrowsize * (s[-1] / len(s)) |
| 1710 | + else: |
| 1711 | + # fallback for very short streamlines |
| 1712 | + arrow_length = ( |
| 1713 | + arrowsize * 0.1 * np.mean([solver.grid.dx, solver.grid.dy]) |
| 1714 | + ) |
| 1715 | + |
| 1716 | + arrow_head = (tx[-1] + ui * arrow_length, ty[-1] + vi * arrow_length) |
| 1717 | + n = len(s) - 1 if len(s) > 0 else 0 |
| 1718 | + else: |
| 1719 | + n = np.searchsorted(s, s[-1] / 2.0) |
| 1720 | + arrow_tail = (tx[n], ty[n]) |
| 1721 | + arrow_head = (np.mean(tx[n : n + 2]), np.mean(ty[n : n + 2])) |
| 1722 | + |
| 1723 | + if isinstance(linewidth, np.ndarray): |
| 1724 | + line_widths = solver.interpgrid(linewidth, tgx, tgy)[:-1] |
| 1725 | + line_kw["linewidth"].extend(line_widths) |
| 1726 | + arrow_kw["linewidth"] = line_widths[n] |
| 1727 | + |
| 1728 | + if use_multicolor_lines: |
| 1729 | + color_values = solver.interpgrid(color, tgx, tgy)[:-1] |
| 1730 | + line_colors.append(color_values) |
| 1731 | + arrow_kw["color"] = cmap(norm(color_values[n])) |
| 1732 | + |
| 1733 | + if not edge: |
| 1734 | + p = mpatches.FancyArrowPatch( |
| 1735 | + arrow_tail, arrow_head, transform=transform, **arrow_kw |
| 1736 | + ) |
| 1737 | + else: |
| 1738 | + continue |
| 1739 | + |
| 1740 | + ds = np.sqrt( |
| 1741 | + (arrow_tail[0] - arrow_head[0]) ** 2 |
| 1742 | + + (arrow_tail[1] - arrow_head[1]) ** 2 |
| 1743 | + ) |
| 1744 | + if ds < 1e-15: |
| 1745 | + continue # remove vanishingly short arrows that cause Patch to fail |
| 1746 | + |
| 1747 | + self.add_patch(p) |
| 1748 | + arrows.append(p) |
| 1749 | + |
| 1750 | + lc = mcollections.LineCollection(streamlines, transform=transform, **line_kw) |
| 1751 | + lc.sticky_edges.x[:] = [ |
| 1752 | + solver.grid.x_origin, |
| 1753 | + solver.grid.x_origin + solver.grid.width, |
| 1754 | + ] |
| 1755 | + lc.sticky_edges.y[:] = [ |
| 1756 | + solver.grid.y_origin, |
| 1757 | + solver.grid.y_origin + solver.grid.height, |
| 1758 | + ] |
| 1759 | + |
| 1760 | + if use_multicolor_lines: |
| 1761 | + lc.set_array(np.ma.hstack(line_colors)) |
| 1762 | + lc.set_cmap(cmap) |
| 1763 | + lc.set_norm(norm) |
| 1764 | + |
| 1765 | + self.add_collection(lc) |
| 1766 | + self.autoscale_view() |
| 1767 | + |
| 1768 | + ac = mcollections.PatchCollection(arrows) |
| 1769 | + stream_container = CurvedQuiverSet(lc, ac) |
| 1770 | + return stream_container |
1518 | 1771 |
|
1519 | 1772 | def _call_native(self, name, *args, **kwargs): |
1520 | 1773 | """ |
@@ -5359,6 +5612,7 @@ def tripcolor(self, *args, **kwargs): |
5359 | 5612 |
|
5360 | 5613 | # Update kwargs and handle cmap |
5361 | 5614 | kw.update(_pop_props(kw, "collection")) |
| 5615 | + |
5362 | 5616 | center_levels = kw.pop("center_levels", None) |
5363 | 5617 | kw = self._parse_cmap( |
5364 | 5618 | triangulation.x, triangulation.y, z, center_levels=center_levels, **kw |
|
0 commit comments