Skip to content

Commit b7d6de8

Browse files
Update _validate_against_pure_literal to fail on Xarray objects (#2527)
1 parent baa50ef commit b7d6de8

File tree

3 files changed

+31
-39
lines changed

3 files changed

+31
-39
lines changed

src/parcels/_typing.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,24 +39,27 @@
3939
KernelFunction = Callable[..., None]
4040

4141

42+
def _is_xarray_object(obj): # with no imports
43+
try:
44+
return "xarray.core" in obj.__module__
45+
except AttributeError:
46+
return False
47+
48+
4249
def _validate_against_pure_literal(value, typing_literal):
4350
"""Uses a Literal type alias to validate.
4451
4552
Can't be used with ``Literal[...] | None`` etc. as its not a pure literal.
4653
"""
54+
# TODO remove once https://github.com/pydata/xarray/issues/11209 is resolved - Xarray objects don't work normally in `in` statements
55+
if _is_xarray_object(value):
56+
raise ValueError(f"Invalid input type {type(value)}")
57+
4758
if value not in get_args(typing_literal):
4859
msg = f"Invalid value {value!r}. Valid options are {get_args(typing_literal)!r}"
4960
raise ValueError(msg)
5061

5162

5263
# Assertion functions to clean user input
53-
def assert_valid_interp_method(value: Any):
54-
_validate_against_pure_literal(value, InterpMethodOption)
55-
56-
5764
def assert_valid_mesh(value: Any):
5865
_validate_against_pure_literal(value, Mesh)
59-
60-
61-
def assert_valid_gridindexingtype(value: Any):
62-
_validate_against_pure_literal(value, GridIndexingType)

tests-v3/test_typing.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

tests/test_typing.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import numpy as np
2+
import pytest
3+
import xarray as xr
4+
5+
from parcels._typing import (
6+
assert_valid_mesh,
7+
)
8+
9+
10+
def test_invalid_assert_valid_mesh():
11+
with pytest.raises(ValueError, match="Invalid value"):
12+
assert_valid_mesh("invalid option")
13+
14+
ds = xr.Dataset({"A": (("a", "b"), np.arange(20).reshape(4, 5))})
15+
with pytest.raises(ValueError, match="Invalid input type"):
16+
assert_valid_mesh(ds)
17+
18+
19+
def test_assert_valid_mesh():
20+
assert_valid_mesh("spherical")

0 commit comments

Comments
 (0)