Skip to content

Commit 0b0fb40

Browse files
authored
Improve alignment typehints (#4522)
* Improve alignment typehints * Fix typing issues * Add a note to what's new
1 parent 881192b commit 0b0fb40

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ New Features
3636
limited regions of existing Zarr stores (:pull:`4035`).
3737
See :ref:`io.zarr.appending` for full details.
3838
By `Stephan Hoyer <https://github.com/shoyer>`_.
39+
- Added typehints in :py:func:`align` to reflect that the same type received in ``objects`` arg will be returned (:pull:`4522`).
40+
By `Michal Baumgartner <https://github.com/m1so>`_.
3941

4042
Bug fixes
4143
~~~~~~~~~

xarray/core/alignment.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,17 @@
22
import operator
33
from collections import defaultdict
44
from contextlib import suppress
5-
from typing import TYPE_CHECKING, Any, Dict, Hashable, Mapping, Optional, Tuple, Union
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
Dict,
9+
Hashable,
10+
Mapping,
11+
Optional,
12+
Tuple,
13+
TypeVar,
14+
Union,
15+
)
616

717
import numpy as np
818
import pandas as pd
@@ -13,9 +23,12 @@
1323
from .variable import IndexVariable, Variable
1424

1525
if TYPE_CHECKING:
26+
from .common import DataWithCoords
1627
from .dataarray import DataArray
1728
from .dataset import Dataset
1829

30+
DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords)
31+
1932

2033
def _get_joiner(join):
2134
if join == "outer":
@@ -59,13 +72,13 @@ def _override_indexes(objects, all_indexes, exclude):
5972

6073

6174
def align(
62-
*objects,
75+
*objects: "DataAlignable",
6376
join="inner",
6477
copy=True,
6578
indexes=None,
6679
exclude=frozenset(),
6780
fill_value=dtypes.NA,
68-
):
81+
) -> Tuple["DataAlignable", ...]:
6982
"""
7083
Given any number of Dataset and/or DataArray objects, returns new
7184
objects with aligned indexes and dimension sizes.
@@ -337,7 +350,9 @@ def align(
337350
# fast path for no reindexing necessary
338351
new_obj = obj.copy(deep=copy)
339352
else:
340-
new_obj = obj.reindex(copy=copy, fill_value=fill_value, **valid_indexers)
353+
new_obj = obj.reindex(
354+
copy=copy, fill_value=fill_value, indexers=valid_indexers
355+
)
341356
new_obj.encoding = obj.encoding
342357
result.append(new_obj)
343358

xarray/core/concat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,8 @@ def _dataset_concat(
380380
dim, coord = _calc_concat_dim_coord(dim)
381381
# Make sure we're working on a copy (we'll be loading variables)
382382
datasets = [ds.copy() for ds in datasets]
383-
datasets = align(
384-
*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value
383+
datasets = list(
384+
align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value)
385385
)
386386

387387
dim_coords, dims_sizes, coord_names, data_names = _parse_datasets(datasets)

0 commit comments

Comments
 (0)