22# Permissions are hereby granted under the terms of the MIT License:
33# https://opensource.org/licenses/MIT.
44
5+ import contextlib
56import shutil
67import unittest
78import warnings
1314from zappend .fsutil .fileobj import FileObj
1415from zappend .slice .cm import SliceSourceContextManager
1516from zappend .slice .cm import open_slice_dataset
17+ from zappend .slice .source import SliceSource
1618from zappend .slice .sources .memory import MemorySliceSource
1719from zappend .slice .sources .persistent import PersistentSliceSource
1820from zappend .slice .sources .temporary import TemporarySliceSource
2325# noinspection PyUnusedLocal
2426
2527
26- # noinspection PyShadowingBuiltins
28+ # noinspection PyShadowingBuiltins,PyRedeclaration
2729class OpenSliceDatasetTest (unittest .TestCase ):
2830 def setUp (self ):
2931 clear_memory_fs ()
3032
31- # noinspection PyMethodMayBeStatic
3233 def test_slice_item_is_slice_source (self ):
3334 dataset = make_test_dataset ()
3435 ctx = Context (dict (target_dir = "memory://target.zarr" ))
@@ -127,7 +128,6 @@ def test_slice_item_is_uri_with_polling_ok(self):
127128 with slice_cm as slice_ds :
128129 self .assertIsInstance (slice_ds , xr .Dataset )
129130
130- # noinspection PyMethodMayBeStatic
131131 def test_slice_item_is_uri_with_polling_fail (self ):
132132 slice_dir = FileObj ("memory://slice.zarr" )
133133 ctx = Context (
@@ -140,3 +140,114 @@ def test_slice_item_is_uri_with_polling_fail(self):
140140 with pytest .raises (FileNotFoundError , match = slice_dir .uri ):
141141 with slice_cm :
142142 pass
143+
144+ def test_slice_item_is_context_manager (self ):
145+ @contextlib .contextmanager
146+ def get_dataset (name ):
147+ uri = f"memory://{ name } .zarr"
148+ ds = make_test_dataset (uri = uri )
149+ try :
150+ yield ds
151+ finally :
152+ ds .close ()
153+ FileObj (uri ).delete (recursive = True )
154+
155+ ctx = Context (
156+ dict (
157+ target_dir = "memory://target.zarr" ,
158+ slice_source = get_dataset ,
159+ )
160+ )
161+ slice_cm = open_slice_dataset (ctx , "bibo" )
162+ self .assertIsInstance (slice_cm , contextlib .AbstractContextManager )
163+ with slice_cm as slice_ds :
164+ self .assertIsInstance (slice_ds , xr .Dataset )
165+
166+ def test_slice_item_is_slice_source (self ):
167+ class MySliceSource (SliceSource ):
168+ def __init__ (self , name ):
169+ self .uri = f"memory://{ name } .zarr"
170+ self .ds = None
171+
172+ def get_dataset (self ):
173+ self .ds = make_test_dataset (uri = self .uri )
174+ return self .ds
175+
176+ def close (self ):
177+ if self .ds is not None :
178+ self .ds .close ()
179+ FileObj (uri = self .uri ).delete (recursive = True )
180+
181+ ctx = Context (
182+ dict (
183+ target_dir = "memory://target.zarr" ,
184+ slice_source = MySliceSource ,
185+ )
186+ )
187+ slice_cm = open_slice_dataset (ctx , "bibo" )
188+ self .assertIsInstance (slice_cm , SliceSourceContextManager )
189+ self .assertIsInstance (slice_cm .slice_source , SliceSource )
190+ with slice_cm as slice_ds :
191+ self .assertIsInstance (slice_ds , xr .Dataset )
192+
193+ def test_slice_item_is_deprecated_slice_source (self ):
194+ class MySliceSource (SliceSource ):
195+ def __init__ (self , name ):
196+ self .uri = f"memory://{ name } .zarr"
197+ self .ds = None
198+
199+ def get_dataset (self ):
200+ self .ds = make_test_dataset (uri = self .uri )
201+ return self .ds
202+
203+ def dispose (self ):
204+ if self .ds is not None :
205+ self .ds .close ()
206+ FileObj (uri = self .uri ).delete (recursive = True )
207+
208+ ctx = Context (
209+ dict (
210+ target_dir = "memory://target.zarr" ,
211+ slice_source = MySliceSource ,
212+ )
213+ )
214+ slice_cm = open_slice_dataset (ctx , "bibo" )
215+ self .assertIsInstance (slice_cm , SliceSourceContextManager )
216+ self .assertIsInstance (slice_cm .slice_source , SliceSource )
217+ with pytest .warns (expected_warning = DeprecationWarning ):
218+ with slice_cm as slice_ds :
219+ self .assertIsInstance (slice_ds , xr .Dataset )
220+
221+
222+ class IsContextManagerTest (unittest .TestCase ):
223+ """Assert that context managers are identified by isinstance()"""
224+
225+ def test_context_manager_class (self ):
226+ @contextlib .contextmanager
227+ def my_slice_source (data ):
228+ ds = xr .Dataset (data )
229+ try :
230+ yield ds
231+ finally :
232+ ds .close ()
233+
234+ item = my_slice_source ([1 , 2 , 3 ])
235+ self .assertTrue (isinstance (item , contextlib .AbstractContextManager ))
236+ self .assertFalse (isinstance (my_slice_source , contextlib .AbstractContextManager ))
237+
238+ def test_context_manager_protocol (self ):
239+ class MySliceSource :
240+ def __enter__ (self ):
241+ return xr .Dataset ()
242+
243+ def __exit__ (self , * exc ):
244+ pass
245+
246+ item = MySliceSource ()
247+ self .assertTrue (isinstance (item , contextlib .AbstractContextManager ))
248+ self .assertFalse (isinstance (MySliceSource , contextlib .AbstractContextManager ))
249+
250+ def test_dataset (self ):
251+ item = xr .Dataset ()
252+ self .assertTrue (isinstance (item , contextlib .AbstractContextManager ))
253+ self .assertFalse (isinstance (xr .Dataset , contextlib .AbstractContextManager ))
0 commit comments