Skip to content

Commit 06763c2

Browse files
committed
More improvements
1 parent c648cfd commit 06763c2

File tree

1 file changed

+33
-14
lines changed

1 file changed

+33
-14
lines changed

xarray/tests/test_state_machine.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,24 @@
88

99
import xarray.testing.strategies as xrst
1010
from xarray import Dataset
11+
from xarray.indexes import PandasMultiIndex
1112
from xarray.testing import _assert_internal_invariants
1213

1314

15+
def get_not_multiindex_dims(ds: Dataset) -> set:
    """Return the dimension names of ``ds`` not indexed by a PandasMultiIndex.

    Every name in ``ds.xindexes`` whose index is a ``PandasMultiIndex`` is
    removed from ``set(ds.dims)``. Names in ``ds.xindexes`` that are not
    dimensions have no effect on the result.
    """
    multiindexed = {
        name
        for name, index in ds.xindexes.items()
        if isinstance(index, PandasMultiIndex)
    }
    return set(ds.dims) - multiindexed
23+
24+
25+
def get_dimension_coordinates(ds: Dataset) -> set:
    """Return the names that are both a dimension and a variable of ``ds``."""
    return set(ds.dims).intersection(ds._variables)
27+
28+
1429
@st.composite
1530
def unique(draw, strategy):
1631
# https://stackoverflow.com/questions/73737073/create-hypothesis-strategy-that-returns-unique-values
@@ -52,45 +67,49 @@ def add_dim_coord(self, var):
5267
@precondition(lambda self: len(self.dataset.dims) >= 1)
5368
def reset_index(self):
5469
dim = random.choice(tuple(self.dataset.dims))
70+
note(f"> resetting {dim}")
5571
self.dataset = self.dataset.reset_index(dim)
5672

5773
@rule(newname=UNIQUE_NAME)
@precondition(lambda self: len(get_not_multiindex_dims(self.dataset)) >= 2)
def stack(self, newname):
    """Stack two distinct, not-multi-indexed dimensions into ``newname``."""
    # Order the candidates before drawing: iterating a set directly is not
    # guaranteed to give the same order on every run, which would make a
    # failure hard to replay.
    candidates = sorted(get_not_multiindex_dims(self.dataset), key=str)
    # Cannot stack repeated dims such as ('0', '0'), so draw without
    # replacement rather than with random.choices.
    oldnames = random.sample(candidates, k=2)
    note(f"> stacking {oldnames} as {newname}")
    self.dataset = self.dataset.stack({newname: oldnames})
6283

6384
@rule()
def unstack(self):
    """Replace the dataset with the result of ``unstack()`` called with no arguments."""
    # Log the step like the other rules do, so a failing run shows the
    # full sequence of operations.
    note("> unstacking")
    self.dataset = self.dataset.unstack()
6687

6788
@rule(newname=UNIQUE_NAME)
@precondition(lambda self: bool(get_dimension_coordinates(self.dataset)))
def rename_vars(self, newname):
    """Rename one randomly chosen dimension coordinate to ``newname``."""
    # benbovy: "skip the default indexes invariant test when the name of an
    # existing dimension coordinate is passed as input kwarg or dict key
    # to .rename_vars()."

    # Sort before choosing so the draw does not depend on set iteration order.
    candidates = sorted(get_dimension_coordinates(self.dataset), key=str)
    oldname = random.choice(candidates)
    self.check_default_indexes = False
    # Log before the call so the step is recorded even if the rename raises.
    note(f"> renaming {oldname} to {newname}")
    self.dataset = self.dataset.rename_vars({oldname: newname})
7799

78100
@rule()
@precondition(lambda self: len(self.dataset._variables) >= 2)
@precondition(lambda self: bool(get_dimension_coordinates(self.dataset)))
def swap_dims(self):
    """Swap a randomly chosen dimension coordinate for a variable along it."""
    ds = self.dataset
    # need a dimension coordinate for swapping; sort first so the draw does
    # not depend on set iteration order
    dim = random.choice(sorted(get_dimension_coordinates(ds), key=str))
    # Can only swap to a variable whose only dimension is `dim`
    targets = [name for name, var in ds._variables.items() if var.dims == (dim,)]
    to = random.choice(targets)
    # Log before the call so the step is recorded even if the swap raises.
    note(f"> swapping {dim} to {to}")
    self.dataset = ds.swap_dims({dim: to})
94113

95114
# TODO: enable when we have serializable attrs only
96115
# @rule()

0 commit comments

Comments
 (0)