Commit 29aaf6c: Sharrow updates (#52)

* add expr to error message
* extra logging
* fix deprecation
* ruffen
* dask_scheduler
* faster sharrow for missing categoricals

1 parent 085ba42

File tree: 4 files changed (+93, -5 lines)

sharrow/aster.py: 23 additions & 1 deletion

@@ -408,7 +408,9 @@ def _replacement(
        if self.get_default or (
            topname == pref_topname and not self.swallow_errors
        ):
-           raise KeyError(f"{topname}..{attr}")
+           raise KeyError(
+               f"{topname}..{attr}\nexpression={self.original_expr}"
+           )
        # we originally raised a KeyError here regardless, but what if
        # we just give back the original node, and see if other spaces,
        # possibly fallback spaces, might work? If nothing works then
@@ -1010,6 +1012,16 @@ def visit_Compare(self, node):
                    f"\ncategories: {left_dictionary}",
                    stacklevel=2,
                )
+               # at this point, the right value is not in the left's categories, so
+               # it is guaranteed to be not equal to any of the categories.
+               if isinstance(node.ops[0], ast.Eq):
+                   result = ast.Constant(False)
+               elif isinstance(node.ops[0], ast.NotEq):
+                   result = ast.Constant(True)
+               else:
+                   raise ValueError(
+                       f"unexpected operator {node.ops[0]}"
+                   ) from None
                if right_decoded is not None:
                    result = ast.Compare(
                        left=left.slice,
@@ -1043,6 +1055,16 @@ def visit_Compare(self, node):
                    f"\ncategories: {right_dictionary}",
                    stacklevel=2,
                )
+               # at this point, the left value is not in the right's categories, so
+               # it is guaranteed to be not equal to any of the categories.
+               if isinstance(node.ops[0], ast.Eq):
+                   result = ast.Constant(False)
+               elif isinstance(node.ops[0], ast.NotEq):
+                   result = ast.Constant(True)
+               else:
+                   raise ValueError(
+                       f"unexpected operator {node.ops[0]}"
+                   ) from None
                if left_decoded is not None:
                    result = ast.Compare(
                        left=ast_Constant(left_decoded),
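The visit_Compare changes above are the "faster sharrow for missing categoricals" item from the commit message: when an == or != comparison references a value that is not in a categorical's dictionary, the result is knowable at transform time, so the rewriter substitutes a constant rather than generating an element-by-element comparison. A minimal standalone sketch of the same folding idea, using only the standard ast module (the helper name and category list are illustrative, not sharrow's API):

import ast

def fold_missing_category(compare: ast.Compare, categories: list) -> ast.AST:
    # An == / != test against a value absent from the categories can be
    # folded: equality is always False, inequality is always True.
    right = compare.comparators[0]
    if not isinstance(right, ast.Constant) or right.value in categories:
        return compare  # known category (or not a constant): leave unchanged
    if isinstance(compare.ops[0], ast.Eq):
        return ast.Constant(False)
    if isinstance(compare.ops[0], ast.NotEq):
        return ast.Constant(True)
    return compare  # ordering operators are not folded here

tree = ast.parse("TourMode2 == 'BAD'", mode="eval")
tree.body = fold_missing_category(tree.body, ["WALK", "BIKE", "CAR"])
print(ast.unparse(tree))  # prints: False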

sharrow/shared_memory.py: 27 additions & 3 deletions

@@ -3,6 +3,7 @@
 import logging
 import os
 import pickle
+import time

 import dask
 import dask.array as da
@@ -247,7 +248,9 @@ def release_shared_memory(self):
     def delete_shared_memory_files(key):
         delete_shared_memory_files(key)

-    def to_shared_memory(self, key=None, mode="r+", _dupe=True):
+    def to_shared_memory(
+        self, key=None, mode="r+", _dupe=True, dask_scheduler="threads"
+    ):
         """
         Load this Dataset into shared memory.

@@ -262,9 +265,13 @@ def to_shared_memory(self, key=None, mode="r+", _dupe=True):
             An identifying key for this shared memory. Use the same key
             in `from_shared_memory` to recreate this Dataset elsewhere.
         mode : {‘r+’, ‘r’, ‘w+’, ‘c’}, optional
-            This methid returns a copy of the Dataset in shared memory.
+            This method returns a copy of the Dataset in shared memory.
             If memmapped, that copy can be opened in various modes.
             See numpy.memmap() for details.
+        dask_scheduler : str, default 'threads'
+            The scheduler to use when loading dask arrays into shared memory.
+            Typically "threads" for multi-threaded reads or "synchronous"
+            for single-threaded reads. See dask.compute() for details.

         Returns
         -------
@@ -287,6 +294,7 @@ def to_shared_memory(self, key=None, mode="r+", _dupe=True):
         def emit(k, a, is_coord):
             nonlocal names, wrappers, sizes, position
             if sparse is not None and isinstance(a.data, sparse.GCXS):
+                logger.info(f"preparing sparse array {a.name}")
                 wrappers.append(
                     {
                         "sparse": True,
@@ -308,6 +316,7 @@ def emit(k, a, is_coord):
                 )
                 a_nbytes = a.data.nbytes
             else:
+                logger.info(f"preparing dense array {a.name}")
                 wrappers.append(
                     {
                         "dims": a.dims,
@@ -335,19 +344,23 @@ def emit(k, a, is_coord):
                 emit(k, a, False)

         mem = create_shared_memory_array(key, size=position)
+
+        logger.info("declaring shared memory buffer")
         if key.startswith("memmap:"):
             buffer = memoryview(mem)
         else:
             buffer = mem.buf

         tasks = []
+        task_names = []
         for w in wrappers:
             _is_sparse = w.get("sparse", False)
             _size = w["nbytes"]
             _name = w["name"]
             _pos = w["position"]
             a = self._obj[_name]
             if _is_sparse:
+                logger.info(f"running load task: {_name} ({si_units(_size)})")
                 ad = a.data
                 _size_d = w["data.nbytes"]
                 _size_i = w["indices.nbytes"]
@@ -373,19 +386,30 @@ def emit(k, a, is_coord):
                 mem_arr_i[:] = ad.indices[:]
                 mem_arr_p[:] = ad.indptr[:]
             else:
+                logger.info(f"preparing load task: {_name} ({si_units(_size)})")
                 mem_arr = np.ndarray(
                     shape=a.shape, dtype=a.dtype, buffer=buffer[_pos : _pos + _size]
                 )
                 if isinstance(a, xr.DataArray) and isinstance(a.data, da.Array):
                     tasks.append(da.store(a.data, mem_arr, lock=False, compute=False))
+                    task_names.append(_name)
                 else:
                     mem_arr[:] = a[:]
         if tasks:
-            dask.compute(tasks, scheduler="threads")
+            t = time.time()
+            logger.info(f"running {len(tasks)} dask data load tasks")
+            if dask_scheduler == "synchronous":
+                for task, task_name in zip(tasks, task_names):
+                    logger.info(f"running load task: {task_name}")
+                    dask.compute(task, scheduler=dask_scheduler)
+            else:
+                dask.compute(tasks, scheduler=dask_scheduler)
+            logger.info(f"completed dask data load in {time.time()-t:.3f} seconds")

         if key.startswith("memmap:"):
             mem.flush()

+        logger.info("storing metadata in shared memory")
         create_shared_list(
             [pickle.dumps(self._obj.attrs)] + [pickle.dumps(i) for i in wrappers], key
         )
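For callers, the visible change is the new dask_scheduler keyword on to_shared_memory. A hedged usage sketch; the shm accessor name and the dataset construction are assumptions for illustration (the self._obj pattern above implies an xarray accessor, but this diff does not show its registration):

import dask.array as da
import xarray as xr
import sharrow  # assumed to register the shared-memory dataset accessor

# illustrative dataset backed by a lazy dask array
ds = xr.Dataset(
    {"skim": (("otaz", "dtaz"), da.random.random((1000, 1000), chunks=500))}
)

# default behavior, unchanged: all load tasks run on the threaded scheduler
ds.shm.to_shared_memory(key="skims", dask_scheduler="threads")

# single-threaded alternative: tasks run one at a time, each with its own
# "running load task" log line, useful when a threaded load hangs or when
# profiling which array dominates load time
ds.shm.to_shared_memory(key="skims-sync", dask_scheduler="synchronous")

Note that in the synchronous branch each task gets its own dask.compute call, so the per-task log lines bracket exactly one array's copy into the shared buffer.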

sharrow/sparse.py: 1 addition & 1 deletion

@@ -163,7 +163,7 @@ def apply_mapper(x):
         raise ImportError("sparse is not installed")

     sparse_data = sparse.GCXS(
-        sparse.COO((i_, j_), data, shape=shape), compressed_axes=(0,)
+        sparse.COO(np.stack((i_, j_)), data, shape=shape), compressed_axes=(0,)
     )
     self._obj[f"_s_{name}"] = xr.DataArray(
         sparse_data,
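Per the commit message, this one-liner is the deprecation fix: newer pydata/sparse releases expect COO coordinates as a single (ndim, nnz) array rather than a tuple of per-axis index arrays, and np.stack builds exactly that shape. A small sketch of the two spellings (toy indices and values, assumed for illustration):

import numpy as np
import sparse

i_ = np.array([0, 1, 2])          # row indices of the nonzeros
j_ = np.array([1, 2, 0])          # column indices of the nonzeros
data = np.array([1.0, 2.0, 3.0])  # nonzero values

# old spelling, deprecated: a tuple of 1-D index arrays
# coo = sparse.COO((i_, j_), data, shape=(3, 3))

# new spelling: one stacked (2, nnz) coordinate array
coo = sparse.COO(np.stack((i_, j_)), data, shape=(3, 3))
print(coo.todense())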

sharrow/tests/test_categorical.py: 42 additions & 0 deletions

@@ -177,6 +177,48 @@ def test_missing_categorical():
     a = a.isel(expressions=0)
     assert all(a == np.asarray([1, 0, 1, 1, 1, 1]))

+    expr = "df.TourMode2 != 'BAD'"
+    with pytest.warns(UserWarning):
+        f8 = tree.setup_flow({expr: expr}, with_root_node_name="df")
+    a = f8.load_dataarray(dtype=np.int8)
+    a = a.isel(expressions=0)
+    assert all(a == np.asarray([1, 1, 1, 1, 1, 1]))
+
+    expr = "'BAD' != df.TourMode2"
+    with pytest.warns(UserWarning):
+        f9 = tree.setup_flow({expr: expr}, with_root_node_name="df")
+    a = f9.load_dataarray(dtype=np.int8)
+    a = a.isel(expressions=0)
+    assert all(a == np.asarray([1, 1, 1, 1, 1, 1]))
+
+    expr = "(df.TourMode2 == 'BAD') * 2"
+    with pytest.warns(UserWarning):
+        fA = tree.setup_flow({expr: expr}, with_root_node_name="df")
+    a = fA.load_dataarray(dtype=np.int8)
+    a = a.isel(expressions=0)
+    assert all(a == np.asarray([0, 0, 0, 0, 0, 0]))
+
+    expr = "(df.TourMode2 == 'BAD') * 2.2"
+    with pytest.warns(UserWarning):
+        fB = tree.setup_flow({expr: expr}, with_root_node_name="df")
+    a = fB.load_dataarray(dtype=np.int8)
+    a = a.isel(expressions=0)
+    assert all(a == np.asarray([0, 0, 0, 0, 0, 0]))
+
+    expr = "np.exp(df.TourMode2 == 'BAD') * 2.2"
+    with pytest.warns(UserWarning):
+        fC = tree.setup_flow({expr: expr}, with_root_node_name="df")
+    a = fC.load_dataarray(dtype=np.float32)
+    a = a.isel(expressions=0)
+    assert all(a == np.asarray([2.2, 2.2, 2.2, 2.2, 2.2, 2.2], dtype=np.float32))
+
+    expr = "(df.TourMode2 != 'BAD') * 2"
+    with pytest.warns(UserWarning):
+        fD = tree.setup_flow({expr: expr}, with_root_node_name="df")
+    a = fD.load_dataarray(dtype=np.int8)
+    a = a.isel(expressions=0)
+    assert all(a == np.asarray([2, 2, 2, 2, 2, 2]))
+

 def test_categorical_indexing(tours_dataset: xr.Dataset, skims_dataset: xr.Dataset):
     tree = sharrow.DataTree(tours=tours_dataset)
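The expected values in these new tests follow directly from the constant folding added in aster.py: once 'BAD' is known to be absent from TourMode2's categories, == folds to False and != folds to True, and ordinary numeric promotion produces the asserted arrays. A quick check of that arithmetic in plain numpy, outside sharrow:

import numpy as np

folded_eq = False   # df.TourMode2 == 'BAD' folds to a constant False
folded_ne = True    # df.TourMode2 != 'BAD' folds to a constant True

print(folded_eq * 2)            # 0, matching fA
print(int(folded_eq * 2.2))     # 0 once cast to int8, matching fB
print(np.exp(folded_eq) * 2.2)  # 2.2, since exp(0) == 1, matching fC
print(folded_ne * 2)            # 2, matching fD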
