Skip to content

Commit e517dcc

Browse files
committed
Trying to resolve mypy issues
1 parent 93d2abc commit e517dcc

File tree

7 files changed

+62
-38
lines changed

7 files changed

+62
-38
lines changed

xarray/core/alignment.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class Aligner(Generic[T_Alignable]):
113113
objects: tuple[T_Alignable, ...]
114114
results: tuple[T_Alignable, ...]
115115
objects_matching_indexes: tuple[dict[MatchingIndexKey, Index], ...]
116-
join: str
116+
join: str | CombineKwargDefault
117117
exclude_dims: frozenset[Hashable]
118118
exclude_vars: frozenset[Hashable]
119119
copy: bool
@@ -133,7 +133,7 @@ class Aligner(Generic[T_Alignable]):
133133
def __init__(
134134
self,
135135
objects: Iterable[T_Alignable],
136-
join: str = "inner",
136+
join: str | CombineKwargDefault = "inner",
137137
indexes: Mapping[Any, Any] | None = None,
138138
exclude_dims: str | Iterable[Hashable] = frozenset(),
139139
exclude_vars: Iterable[Hashable] = frozenset(),
@@ -146,7 +146,14 @@ def __init__(
146146
self.objects = tuple(objects)
147147
self.objects_matching_indexes = ()
148148

149-
if join not in ["inner", "outer", "override", "exact", "left", "right"]:
149+
if not isinstance(join, CombineKwargDefault) and join not in [
150+
"inner",
151+
"outer",
152+
"override",
153+
"exact",
154+
"left",
155+
"right",
156+
]:
150157
raise ValueError(f"invalid value for join: {join}")
151158
self.join = join
152159

@@ -618,7 +625,7 @@ def align(
618625
obj1: T_Obj1,
619626
/,
620627
*,
621-
join: JoinOptions = "inner",
628+
join: JoinOptions | CombineKwargDefault = "inner",
622629
copy: bool = True,
623630
indexes=None,
624631
exclude: str | Iterable[Hashable] = frozenset(),
@@ -632,7 +639,7 @@ def align(
632639
obj2: T_Obj2,
633640
/,
634641
*,
635-
join: JoinOptions = "inner",
642+
join: JoinOptions | CombineKwargDefault = "inner",
636643
copy: bool = True,
637644
indexes=None,
638645
exclude: str | Iterable[Hashable] = frozenset(),
@@ -647,7 +654,7 @@ def align(
647654
obj3: T_Obj3,
648655
/,
649656
*,
650-
join: JoinOptions = "inner",
657+
join: JoinOptions | CombineKwargDefault = "inner",
651658
copy: bool = True,
652659
indexes=None,
653660
exclude: str | Iterable[Hashable] = frozenset(),
@@ -663,7 +670,7 @@ def align(
663670
obj4: T_Obj4,
664671
/,
665672
*,
666-
join: JoinOptions = "inner",
673+
join: JoinOptions | CombineKwargDefault = "inner",
667674
copy: bool = True,
668675
indexes=None,
669676
exclude: str | Iterable[Hashable] = frozenset(),
@@ -680,7 +687,7 @@ def align(
680687
obj5: T_Obj5,
681688
/,
682689
*,
683-
join: JoinOptions = "inner",
690+
join: JoinOptions | CombineKwargDefault = "inner",
684691
copy: bool = True,
685692
indexes=None,
686693
exclude: str | Iterable[Hashable] = frozenset(),
@@ -691,7 +698,7 @@ def align(
691698
@overload
692699
def align(
693700
*objects: T_Alignable,
694-
join: JoinOptions = "inner",
701+
join: JoinOptions | CombineKwargDefault = "inner",
695702
copy: bool = True,
696703
indexes=None,
697704
exclude: str | Iterable[Hashable] = frozenset(),
@@ -701,7 +708,7 @@ def align(
701708

702709
def align(
703710
*objects: T_Alignable,
704-
join: JoinOptions = "inner",
711+
join: JoinOptions | CombineKwargDefault = "inner",
705712
copy: bool = True,
706713
indexes=None,
707714
exclude: str | Iterable[Hashable] = frozenset(),

xarray/core/combine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def _combine_all_along_first_dim(
269269
dim,
270270
data_vars,
271271
coords,
272-
compat: CompatOptions,
272+
compat: CompatOptions | CombineKwargDefault,
273273
fill_value,
274274
join: JoinOptions | CombineKwargDefault,
275275
combine_attrs: CombineAttrsOptions,
@@ -298,7 +298,7 @@ def _combine_all_along_first_dim(
298298
def _combine_1d(
299299
datasets,
300300
concat_dim,
301-
compat: CompatOptions,
301+
compat: CompatOptions | CombineKwargDefault,
302302
data_vars,
303303
coords,
304304
fill_value,
@@ -365,7 +365,7 @@ def _nested_combine(
365365
return Dataset()
366366

367367
if isinstance(concat_dim, str | DataArray) or concat_dim is None:
368-
concat_dim = [concat_dim] # type: ignore[assignment]
368+
concat_dim = [concat_dim]
369369

370370
# Arrange datasets for concatenation
371371
# Use information from the shape of the user input

xarray/core/concat.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def concat(
6565
compat: CompatOptions | CombineKwargDefault = _COMPAT_CONCAT_DEFAULT,
6666
positions: Iterable[Iterable[int]] | None = None,
6767
fill_value: object = dtypes.NA,
68-
join: JoinOptions | None = None,
68+
join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT,
6969
combine_attrs: CombineAttrsOptions = "override",
7070
create_index_for_new_dim: bool = True,
7171
) -> T_DataArray: ...
@@ -334,7 +334,7 @@ def _calc_concat_over(
334334
datasets,
335335
dim,
336336
dim_names,
337-
data_vars: T_DataVars,
337+
data_vars: T_DataVars | CombineKwargDefault,
338338
coords,
339339
compat,
340340
):
@@ -485,7 +485,7 @@ def process_subset_opt(opt, subset):
485485
)
486486
concat_over.update(opt)
487487

488-
warnings = []
488+
warnings: list[str] = []
489489
process_subset_opt(data_vars, "data_vars")
490490
process_subset_opt(coords, "coords")
491491

@@ -534,7 +534,7 @@ def _dataset_concat(
534534
datasets: Iterable[T_Dataset],
535535
dim: str | T_Variable | T_DataArray | pd.Index,
536536
data_vars: T_DataVars | CombineKwargDefault,
537-
coords: str | list[str] | CombineKwargDefault,
537+
coords: str | list[Hashable] | CombineKwargDefault,
538538
compat: CompatOptions | CombineKwargDefault,
539539
positions: Iterable[Iterable[int]] | None,
540540
fill_value: Any,
@@ -780,7 +780,7 @@ def _dataarray_concat(
780780
arrays: Iterable[T_DataArray],
781781
dim: str | T_Variable | T_DataArray | pd.Index,
782782
data_vars: T_DataVars | CombineKwargDefault,
783-
coords: str | list[str] | CombineKwargDefault,
783+
coords: str | list[Hashable] | CombineKwargDefault,
784784
compat: CompatOptions | CombineKwargDefault,
785785
positions: Iterable[Iterable[int]] | None,
786786
fill_value: object,

xarray/core/merge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik
527527
def _get_priority_vars_and_indexes(
528528
objects: Sequence[DatasetLike],
529529
priority_arg: int | None,
530-
compat: CompatOptions = "equals",
530+
compat: CompatOptions | CombineKwargDefault = "equals",
531531
) -> dict[Hashable, MergeElement]:
532532
"""Extract the priority variable from a list of mappings.
533533

xarray/tests/test_backends.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import uuid
1515
import warnings
1616
from collections.abc import Generator, Iterator, Mapping
17-
from contextlib import ExitStack, nullcontext
17+
from contextlib import ExitStack
1818
from io import BytesIO
1919
from os import listdir
2020
from pathlib import Path
@@ -4614,12 +4614,16 @@ def test_open_mfdataset_dataset_combine_attrs(
46144614
with set_options(
46154615
use_new_combine_kwarg_defaults=use_new_combine_kwarg_defaults
46164616
):
4617-
warning = (
4617+
warning: contextlib.AbstractContextManager = (
46184618
pytest.warns(FutureWarning)
46194619
if not use_new_combine_kwarg_defaults
4620-
else nullcontext()
4620+
else contextlib.nullcontext()
4621+
)
4622+
error: contextlib.AbstractContextManager = (
4623+
pytest.raises(xr.MergeError)
4624+
if expect_error
4625+
else contextlib.nullcontext()
46214626
)
4622-
error = pytest.raises(xr.MergeError) if expect_error else nullcontext()
46234627
with warning:
46244628
with error:
46254629
with xr.open_mfdataset(
@@ -4785,13 +4789,15 @@ def test_open_mfdataset_warns_when_kwargs_set_to_different(
47854789
xr.concat([ds1, ds2], dim="t", **kwargs)
47864790

47874791
with set_options(use_new_combine_kwarg_defaults=False):
4788-
if "data_vars" not in kwargs:
4789-
expectation = pytest.warns(
4792+
expectation: contextlib.AbstractContextManager = (
4793+
pytest.warns(
47904794
FutureWarning,
47914795
match="will change from data_vars='all'",
47924796
)
4793-
else:
4794-
expectation = nullcontext()
4797+
if "data_vars" not in kwargs
4798+
else contextlib.nullcontext()
4799+
)
4800+
47954801
with pytest.warns(
47964802
FutureWarning,
47974803
match="will change from compat='equals'",

xarray/tests/test_concat.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4-
from contextlib import nullcontext
4+
from contextlib import AbstractContextManager, nullcontext
55
from copy import deepcopy
66
from typing import TYPE_CHECKING, Any, Literal
77

@@ -1476,13 +1476,14 @@ def test_concat_coords_kwarg(
14761476
FutureWarning,
14771477
match="will change from data_vars='all' to data_vars='minimal'",
14781478
):
1479-
if coords == "different":
1480-
expectation = pytest.warns(
1479+
expectation: AbstractContextManager = (
1480+
pytest.warns(
14811481
FutureWarning,
14821482
match="will change from compat='equals' to compat='override'",
14831483
)
1484-
else:
1485-
expectation = nullcontext()
1484+
if coords == "different"
1485+
else nullcontext()
1486+
)
14861487
with expectation:
14871488
old = concat(datasets, data["dim1"], coords=coords)
14881489

xarray/util/deprecation_helpers.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,11 @@ def wrapper(*args, **kwargs):
148148
return wrapper # type: ignore[return-value]
149149

150150

151-
class CombineKwargDefault(ReprObject):
152-
"""Object that handles deprecation cycle for kwarg default values."""
151+
class CombineKwargDefault:
152+
"""Object that handles deprecation cycle for kwarg default values.
153+
154+
Similar to ReprObject
155+
"""
153156

154157
_old: str
155158
_new: str
@@ -160,22 +163,29 @@ def __init__(self, *, name: str, old: str, new: str):
160163
self._old = old
161164
self._new = new
162165

166+
def __repr__(self) -> str:
167+
return self._value
168+
163169
def __eq__(self, other: ReprObject | Any) -> bool:
164-
# TODO: What type can other be? ArrayLike?
165170
return (
166171
self._value == other._value
167172
if isinstance(other, ReprObject)
168173
else self._value == other
169174
)
170175

171176
@property
172-
def _value(self):
177+
def _value(self) -> str:
173178
return self._new if OPTIONS["use_new_combine_kwarg_defaults"] else self._old
174179

175180
def __hash__(self) -> int:
176181
return hash(self._value)
177182

178-
def warning_message(self, message: str, recommend_set_options: bool = True):
183+
def __dask_tokenize__(self) -> object:
184+
from dask.base import normalize_token
185+
186+
return normalize_token((type(self), self._value))
187+
188+
def warning_message(self, message: str, recommend_set_options: bool = True) -> str:
179189
if recommend_set_options:
180190
recommendation = (
181191
" To opt in to new defaults and get rid of these warnings now "
@@ -194,7 +204,7 @@ def warning_message(self, message: str, recommend_set_options: bool = True):
194204
+ recommendation
195205
)
196206

197-
def error_message(self):
207+
def error_message(self) -> str:
198208
return (
199209
f" Error might be related to new default ({self._name}={self._new!r}). "
200210
f"Previously the default was {self._name}={self._old!r}. "

0 commit comments

Comments
 (0)