Skip to content

Commit 8ba2f7f

Browse files
committed
Slice items can now be a contextlib.AbstractContextManager; deprecated SliceSource.dispose() and introduced SliceSource.close().
1 parent 70e739c commit 8ba2f7f

File tree

8 files changed

+202
-32
lines changed

8 files changed

+202
-32
lines changed

tests/slice/test_cm.py

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Permissions are hereby granted under the terms of the MIT License:
33
# https://opensource.org/licenses/MIT.
44

5+
import contextlib
56
import shutil
67
import unittest
78
import warnings
@@ -13,6 +14,7 @@
1314
from zappend.fsutil.fileobj import FileObj
1415
from zappend.slice.cm import SliceSourceContextManager
1516
from zappend.slice.cm import open_slice_dataset
17+
from zappend.slice.source import SliceSource
1618
from zappend.slice.sources.memory import MemorySliceSource
1719
from zappend.slice.sources.persistent import PersistentSliceSource
1820
from zappend.slice.sources.temporary import TemporarySliceSource
@@ -23,12 +25,11 @@
2325
# noinspection PyUnusedLocal
2426

2527

26-
# noinspection PyShadowingBuiltins
28+
# noinspection PyShadowingBuiltins,PyRedeclaration
2729
class OpenSliceDatasetTest(unittest.TestCase):
2830
def setUp(self):
2931
clear_memory_fs()
3032

31-
# noinspection PyMethodMayBeStatic
3233
def test_slice_item_is_slice_source(self):
3334
dataset = make_test_dataset()
3435
ctx = Context(dict(target_dir="memory://target.zarr"))
@@ -127,7 +128,6 @@ def test_slice_item_is_uri_with_polling_ok(self):
127128
with slice_cm as slice_ds:
128129
self.assertIsInstance(slice_ds, xr.Dataset)
129130

130-
# noinspection PyMethodMayBeStatic
131131
def test_slice_item_is_uri_with_polling_fail(self):
132132
slice_dir = FileObj("memory://slice.zarr")
133133
ctx = Context(
@@ -140,3 +140,114 @@ def test_slice_item_is_uri_with_polling_fail(self):
140140
with pytest.raises(FileNotFoundError, match=slice_dir.uri):
141141
with slice_cm:
142142
pass
143+
144+
def test_slice_item_is_context_manager(self):
145+
@contextlib.contextmanager
146+
def get_dataset(name):
147+
uri = f"memory://{name}.zarr"
148+
ds = make_test_dataset(uri=uri)
149+
try:
150+
yield ds
151+
finally:
152+
ds.close()
153+
FileObj(uri).delete(recursive=True)
154+
155+
ctx = Context(
156+
dict(
157+
target_dir="memory://target.zarr",
158+
slice_source=get_dataset,
159+
)
160+
)
161+
slice_cm = open_slice_dataset(ctx, "bibo")
162+
self.assertIsInstance(slice_cm, contextlib.AbstractContextManager)
163+
with slice_cm as slice_ds:
164+
self.assertIsInstance(slice_ds, xr.Dataset)
165+
166+
def test_slice_item_is_slice_source(self):
167+
class MySliceSource(SliceSource):
168+
def __init__(self, name):
169+
self.uri = f"memory://{name}.zarr"
170+
self.ds = None
171+
172+
def get_dataset(self):
173+
self.ds = make_test_dataset(uri=self.uri)
174+
return self.ds
175+
176+
def close(self):
177+
if self.ds is not None:
178+
self.ds.close()
179+
FileObj(uri=self.uri).delete(recursive=True)
180+
181+
ctx = Context(
182+
dict(
183+
target_dir="memory://target.zarr",
184+
slice_source=MySliceSource,
185+
)
186+
)
187+
slice_cm = open_slice_dataset(ctx, "bibo")
188+
self.assertIsInstance(slice_cm, SliceSourceContextManager)
189+
self.assertIsInstance(slice_cm.slice_source, SliceSource)
190+
with slice_cm as slice_ds:
191+
self.assertIsInstance(slice_ds, xr.Dataset)
192+
193+
def test_slice_item_is_deprecated_slice_source(self):
194+
class MySliceSource(SliceSource):
195+
def __init__(self, name):
196+
self.uri = f"memory://{name}.zarr"
197+
self.ds = None
198+
199+
def get_dataset(self):
200+
self.ds = make_test_dataset(uri=self.uri)
201+
return self.ds
202+
203+
def dispose(self):
204+
if self.ds is not None:
205+
self.ds.close()
206+
FileObj(uri=self.uri).delete(recursive=True)
207+
208+
ctx = Context(
209+
dict(
210+
target_dir="memory://target.zarr",
211+
slice_source=MySliceSource,
212+
)
213+
)
214+
slice_cm = open_slice_dataset(ctx, "bibo")
215+
self.assertIsInstance(slice_cm, SliceSourceContextManager)
216+
self.assertIsInstance(slice_cm.slice_source, SliceSource)
217+
with pytest.warns(expected_warning=DeprecationWarning):
218+
with slice_cm as slice_ds:
219+
self.assertIsInstance(slice_ds, xr.Dataset)
220+
221+
222+
class IsContextManagerTest(unittest.TestCase):
223+
"""Assert that context managers are identified by isinstance()"""
224+
225+
def test_context_manager_class(self):
226+
@contextlib.contextmanager
227+
def my_slice_source(data):
228+
ds = xr.Dataset(data)
229+
try:
230+
yield ds
231+
finally:
232+
ds.close()
233+
234+
item = my_slice_source([1, 2, 3])
235+
self.assertTrue(isinstance(item, contextlib.AbstractContextManager))
236+
self.assertFalse(isinstance(my_slice_source, contextlib.AbstractContextManager))
237+
238+
def test_context_manager_protocol(self):
239+
class MySliceSource:
240+
def __enter__(self):
241+
return xr.Dataset()
242+
243+
def __exit__(self, *exc):
244+
pass
245+
246+
item = MySliceSource()
247+
self.assertTrue(isinstance(item, contextlib.AbstractContextManager))
248+
self.assertFalse(isinstance(MySliceSource, contextlib.AbstractContextManager))
249+
250+
def test_dataset(self):
251+
item = xr.Dataset()
252+
self.assertTrue(isinstance(item, contextlib.AbstractContextManager))
253+
self.assertFalse(isinstance(xr.Dataset, contextlib.AbstractContextManager))

tests/slice/test_source.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,19 @@
22
# Permissions are hereby granted under the terms of the MIT License:
33
# https://opensource.org/licenses/MIT.
44

5-
import shutil
65
import unittest
7-
import warnings
86

97
import pytest
108
import xarray as xr
119

1210
from zappend.context import Context
1311
from zappend.fsutil.fileobj import FileObj
14-
from zappend.slice.cm import SliceSourceContextManager
15-
from zappend.slice.cm import open_slice_dataset
16-
from zappend.slice.source import to_slice_source, SliceSource
12+
from zappend.slice.source import SliceSource
13+
from zappend.slice.source import to_slice_source
1714
from zappend.slice.sources.memory import MemorySliceSource
1815
from zappend.slice.sources.persistent import PersistentSliceSource
1916
from zappend.slice.sources.temporary import TemporarySliceSource
2017
from tests.helpers import clear_memory_fs
21-
from tests.helpers import make_test_dataset
22-
from tests.config.test_config import CustomSliceSource
2318

2419

2520
# noinspection PyUnusedLocal
@@ -86,22 +81,43 @@ def my_slice_source(arg1, arg2=None, ctx=None):
8681
return xr.Dataset(attrs=dict(arg1=arg1, arg2=arg2, ctx=ctx))
8782

8883
ctx = make_ctx(slice_source=my_slice_source)
89-
arg = xr.Dataset()
9084
slice_source = to_slice_source(ctx, ([13], {"arg2": True}), 0)
9185
self.assertIsInstance(slice_source, MemorySliceSource)
9286
ds = slice_source.get_dataset()
9387
self.assertEqual(13, ds.attrs.get("arg1"))
9488
self.assertEqual(True, ds.attrs.get("arg2"))
9589
self.assertIs(ctx, ds.attrs.get("ctx"))
9690

91+
def test_slice_item_is_slice_source_context_manager(self):
92+
import contextlib
93+
94+
@contextlib.contextmanager
95+
def my_slice_source(ctx, arg1, arg2=None):
96+
_ds = xr.Dataset(attrs=dict(arg1=arg1, arg2=arg2, ctx=ctx))
97+
try:
98+
yield _ds
99+
finally:
100+
_ds.close()
101+
102+
ctx = make_ctx(slice_source=my_slice_source)
103+
slice_source = to_slice_source(ctx, ([14], {"arg2": "OK"}), 0)
104+
self.assertIsInstance(slice_source, contextlib.AbstractContextManager)
105+
with slice_source as ds:
106+
self.assertIsInstance(ds, xr.Dataset)
107+
self.assertEqual(14, ds.attrs.get("arg1"))
108+
self.assertEqual("OK", ds.attrs.get("arg2"))
109+
self.assertIs(ctx, ds.attrs.get("ctx"))
110+
97111
# noinspection PyMethodMayBeStatic
98112
def test_raises_if_slice_item_is_int(self):
99113
ctx = make_ctx(persist_mem_slices=True)
100114
with pytest.raises(
101115
TypeError,
102116
match=(
103117
"slice_item must have type str, xarray.Dataset,"
104-
" zappend.api.FileObj, zappend.api.SliceSource, but was type int"
118+
" contextlib.AbstractContextManager,"
119+
" zappend.api.FileObj, zappend.api.SliceSource,"
120+
" but was type int"
105121
),
106122
):
107123
to_slice_source(ctx, 42, 0)
@@ -116,7 +132,9 @@ def hallo():
116132
TypeError,
117133
match=(
118134
"slice_item must have type str, xarray.Dataset,"
119-
" zappend.api.FileObj, zappend.api.SliceSource, but was type function"
135+
" contextlib.AbstractContextManager,"
136+
" zappend.api.FileObj, zappend.api.SliceSource,"
137+
" but was type function"
120138
),
121139
):
122140
to_slice_source(ctx, hallo, 0)

tests/test_api.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __init__(self, slice_ds):
150150
def get_dataset(self) -> xr.Dataset:
151151
return self.slice_ds.drop_vars(["tsm"])
152152

153-
def dispose(self):
153+
def close(self):
154154
pass
155155

156156
target_dir = "memory://target.zarr"
@@ -168,6 +168,28 @@ def dispose(self):
168168
ds.attrs,
169169
)
170170

171+
def test_some_slices_with_slice_source_cm(self):
172+
import contextlib
173+
174+
@contextlib.contextmanager
175+
def drop_tsm(slice_ds: xr.Dataset) -> xr.Dataset:
176+
yield slice_ds.drop_vars(["tsm"])
177+
178+
target_dir = "memory://target.zarr"
179+
slices = [make_test_dataset(index=3 * i) for i in range(3)]
180+
zappend(slices, target_dir=target_dir, slice_source=drop_tsm)
181+
ds = xr.open_zarr(target_dir)
182+
self.assertEqual({"time": 9, "y": 50, "x": 100}, ds.sizes)
183+
self.assertEqual({"chl"}, set(ds.data_vars))
184+
self.assertEqual({"time", "y", "x"}, set(ds.coords))
185+
self.assertEqual(
186+
{
187+
"Conventions": "CF-1.8",
188+
"title": "Test 1-3",
189+
},
190+
ds.attrs,
191+
)
192+
171193
def test_some_slices_with_slice_source_func(self):
172194
def drop_tsm(slice_ds: xr.Dataset) -> xr.Dataset:
173195
return slice_ds.drop_vars(["tsm"])

zappend/slice/cm.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# https://opensource.org/licenses/MIT.
44

55
import contextlib
6-
from typing import Any
6+
from typing import Any, ContextManager
77

88
import xarray as xr
99

@@ -15,6 +15,8 @@
1515
class SliceSourceContextManager(contextlib.AbstractContextManager):
1616
"""A context manager that wraps a slice source.
1717
18+
Internal class, no API.
19+
1820
Args:
1921
slice_source: The slice source.
2022
"""
@@ -30,15 +32,15 @@ def __enter__(self) -> xr.Dataset:
3032
return self._slice_source.get_dataset()
3133

3234
def __exit__(self, *exception_args):
33-
self._slice_source.dispose()
35+
self._slice_source.close()
3436
self._slice_source = None
3537

3638

3739
def open_slice_dataset(
3840
ctx: Context,
3941
slice_item: Any,
4042
slice_index: int = 0,
41-
) -> SliceSourceContextManager:
43+
) -> ContextManager[xr.Dataset]:
4244
"""Open the slice source for given slice item `slice_item`.
4345
4446
The intended and only use of the returned slice source is as context
@@ -77,4 +79,7 @@ class derived from `zappend.slice.SliceSource` or a function that returns
7779
A new slice source instance
7880
"""
7981
slice_source = to_slice_source(ctx, slice_item, slice_index)
80-
return SliceSourceContextManager(slice_source)
82+
if isinstance(slice_source, contextlib.AbstractContextManager):
83+
return slice_source
84+
else:
85+
return SliceSourceContextManager(slice_source)

0 commit comments

Comments
 (0)