Skip to content

Commit a17a00f

Browse files
authored
Add typing to some interval functions (#6862)
1 parent 670d14d commit a17a00f

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

xarray/plot/utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66
from datetime import datetime
77
from inspect import getfullargspec
8-
from typing import Any, Iterable, Mapping
8+
from typing import Any, Iterable, Mapping, Sequence
99

1010
import numpy as np
1111
import pandas as pd
@@ -502,7 +502,7 @@ def label_from_attrs(da, extra: str = "") -> str:
502502
return "\n".join(textwrap.wrap(name + extra + units, 30))
503503

504504

505-
def _interval_to_mid_points(array):
505+
def _interval_to_mid_points(array: Iterable[pd.Interval]) -> np.ndarray:
506506
"""
507507
Helper function which returns an array
508508
with the Intervals' mid points.
@@ -511,7 +511,7 @@ def _interval_to_mid_points(array):
511511
return np.array([x.mid for x in array])
512512

513513

514-
def _interval_to_bound_points(array):
514+
def _interval_to_bound_points(array: Sequence[pd.Interval]) -> np.ndarray:
515515
"""
516516
Helper function which returns an array
517517
with the Intervals' boundaries.
@@ -523,7 +523,9 @@ def _interval_to_bound_points(array):
523523
return array_boundaries
524524

525525

526-
def _interval_to_double_bound_points(xarray, yarray):
526+
def _interval_to_double_bound_points(
527+
xarray: Iterable[pd.Interval], yarray: Iterable
528+
) -> tuple[np.ndarray, np.ndarray]:
527529
"""
528530
Helper function to deal with a xarray consisting of pd.Intervals. Each
529531
interval is replaced with both boundaries. I.e. the length of xarray
@@ -533,13 +535,15 @@ def _interval_to_double_bound_points(xarray, yarray):
533535
xarray1 = np.array([x.left for x in xarray])
534536
xarray2 = np.array([x.right for x in xarray])
535537

536-
xarray = list(itertools.chain.from_iterable(zip(xarray1, xarray2)))
537-
yarray = list(itertools.chain.from_iterable(zip(yarray, yarray)))
538+
xarray_out = np.array(list(itertools.chain.from_iterable(zip(xarray1, xarray2))))
539+
yarray_out = np.array(list(itertools.chain.from_iterable(zip(yarray, yarray))))
538540

539-
return xarray, yarray
541+
return xarray_out, yarray_out
540542

541543

542-
def _resolve_intervals_1dplot(xval, yval, kwargs):
544+
def _resolve_intervals_1dplot(
545+
xval: np.ndarray, yval: np.ndarray, kwargs: dict
546+
) -> tuple[np.ndarray, np.ndarray, str, str, dict]:
543547
"""
544548
Helper function to replace the values of x and/or y coordinate arrays
545549
containing pd.Interval with their mid-points or - for step plots - double

0 commit comments

Comments
 (0)