Commit cebcc04

Use map_blocks in array.store to avoid materialization and dropping of annotations (dask#11844)
1 parent: 5a3dfb4

File tree: 2 files changed, +138 −236 lines

dask/array/core.py

Lines changed: 76 additions & 198 deletions
@@ -28,6 +28,7 @@
 from toolz import frequencies
 
 from dask._compatibility import import_optional_dependency
+from dask.core import flatten
 
 xr = import_optional_dependency("xarray", errors="ignore")
 
@@ -60,6 +61,7 @@
     persist,
     tokenize,
 )
+from dask.blockwise import BlockwiseDep
 from dask.blockwise import blockwise as core_blockwise
 from dask.blockwise import broadcast_dimensions
 from dask.context import globalmethod
@@ -68,7 +70,7 @@
 from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
 from dask.layers import ArrayBlockIdDep, ArraySliceDep, ArrayValuesDep
 from dask.sizeof import sizeof
-from dask.typing import Graph, Key, NestedKeys
+from dask.typing import Graph, NestedKeys
 from dask.utils import (
     IndexCallable,
     SerializableLock,
@@ -79,6 +81,7 @@
     derived_from,
     format_bytes,
     funcname,
+    get_scheduler_lock,
     has_keyword,
     is_arraylike,
     is_dataframe_like,
@@ -824,11 +827,15 @@ def map_blocks(
         new_axis = [new_axis]  # TODO: handle new_axis
 
     arrs = [a for a in args if isinstance(a, Array)]
+    argpairs = []
+    for a in args:
+        if isinstance(a, Array):
+            argpairs.append((a, tuple(range(a.ndim))[::-1]))
+        elif isinstance(a, BlockwiseDep):
+            argpairs.append((a, tuple(range(args[0].ndim))[::-1]))
+        else:
+            argpairs.append((a, None))
 
-    argpairs = [
-        (a, tuple(range(a.ndim))[::-1]) if isinstance(a, Array) else (a, None)
-        for a in args
-    ]
     if arrs:
         out_ind = tuple(range(max(a.ndim for a in arrs)))[::-1]
     else:
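
The new branch lets map_blocks accept BlockwiseDep arguments (such as ArraySliceDep) positionally, so the block-level function receives a per-block value instead of an array chunk. A minimal sketch of the pattern, using the public map_blocks API plus the ArraySliceDep import added above (function and variable names here are illustrative, not from the commit):

import dask.array as da
from dask.layers import ArraySliceDep

x = da.ones((4, 4), chunks=(2, 2))

def tag_block(block, slc):
    # `slc` arrives as the tuple of slices locating this block inside
    # `x`, supplied per block by ArraySliceDep rather than as a chunk.
    return block * (slc[0].start + 1)

y = x.map_blocks(tag_block, ArraySliceDep(x.chunks), meta=x._meta)
y.compute()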
@@ -1189,85 +1196,54 @@ def store(
     )
     del regions
 
-    # Optimize all sources together
-    sources_hlg = HighLevelGraph.merge(*[e.__dask_graph__() for e in sources])
-    sources_layer = Array.__dask_optimize__(
-        sources_hlg, list(core.flatten([e.__dask_keys__() for e in sources]))
-    )
-    sources_name = "store-sources-" + tokenize(sources)
-    layers = {sources_name: sources_layer}
-    dependencies: dict[str, set[str]] = {sources_name: set()}
-
-    # Optimize all targets together
-    targets_keys = []
-    targets_dsks = []
-    for t in targets:
-        if isinstance(t, Delayed):
-            targets_keys.append(t.key)
-            targets_dsks.append(t.__dask_graph__())
-        elif is_dask_collection(t):
-            raise TypeError("Targets must be either Delayed objects or array-likes")
-
-    if targets_dsks:
-        targets_hlg = HighLevelGraph.merge(*targets_dsks)
-        targets_layer = Delayed.__dask_optimize__(targets_hlg, targets_keys)
-        targets_name = "store-targets-" + tokenize(targets_keys)
-        layers[targets_name] = targets_layer
-        dependencies[targets_name] = set()
-
     if load_stored is None:
         load_stored = return_stored and not compute
 
-    map_names = [
-        "store-map-" + tokenize(s, t if isinstance(t, Delayed) else id(t), r)
-        for s, t, r in zip(sources, targets, regions_list)
-    ]
-    map_keys: list[tuple] = []
-
-    for s, t, n, r in zip(sources, targets, map_names, regions_list):
-        map_layer = insert_to_ooc(
-            keys=s.__dask_keys__(),
-            chunks=s.chunks,
-            out=t.key if isinstance(t, Delayed) else t,
-            name=n,
-            lock=lock,
-            region=r,
-            return_stored=return_stored,
-            load_stored=load_stored,
-        )
-        layers[n] = map_layer
-        if isinstance(t, Delayed):
-            dependencies[n] = {sources_name, targets_name}
-        else:
-            dependencies[n] = {sources_name}
-        map_keys += map_layer.keys()
-
-    if return_stored:
-        store_dsk = HighLevelGraph(layers, dependencies)
-        load_store_dsk: HighLevelGraph | dict[tuple, Any] = store_dsk
-        if compute:
-            store_dlyds = [Delayed(k, store_dsk, layer=k[0]) for k in map_keys]
-            store_dlyds = persist(*store_dlyds, **kwargs)
-            store_dsk_2 = HighLevelGraph.merge(*[e.dask for e in store_dlyds])
-            load_store_dsk = retrieve_from_ooc(map_keys, store_dsk, store_dsk_2)
-            map_names = ["load-" + n for n in map_names]
-
-        return tuple(
-            Array(load_store_dsk, n, s.chunks, meta=s)
-            for s, n in zip(sources, map_names)
+    if lock is True:
+        lock = get_scheduler_lock(collection=Array, scheduler=kwargs.get("scheduler"))
+
+    arrays = []
+    for s, t, r in zip(sources, targets, regions_list):
+        slices = ArraySliceDep(s.chunks)
+        arrays.append(
+            s.map_blocks(
+                load_store_chunk,  # type: ignore[arg-type]
+                t,
+                # Note: slices / BlockwiseDep have to be passed by arg, not by kwarg
+                slices,
+                region=r,
+                lock=lock,
+                return_stored=return_stored,
+                load_stored=load_stored,
+                name="store-map",
+                meta=s._meta,
+            )
         )
 
-    elif compute:
-        store_dsk = HighLevelGraph(layers, dependencies)
-        compute_as_if_collection(Array, store_dsk, map_keys, **kwargs)
-        return None
+    if compute:
+        if not return_stored:
+            import dask
 
-    else:
-        key = "store-" + tokenize(map_names)
-        layers[key] = {key: map_keys}
-        dependencies[key] = set(map_names)
-        store_dsk = HighLevelGraph(layers, dependencies)
-        return Delayed(key, store_dsk)
+            dask.compute(arrays, **kwargs)
+            return None
+        else:
+            stored_persisted = persist(*arrays, **kwargs)
+            arrays = []
+            for s, r in zip(stored_persisted, regions_list):
+                slices = ArraySliceDep(s.chunks)
+                arrays.append(
+                    s.map_blocks(
+                        load_chunk,  # type: ignore[arg-type]
+                        # Note: slices / BlockwiseDep have to be passed by arg, not by kwarg
+                        slices,
+                        lock=lock,
+                        region=r,
+                        meta=s._meta,
+                    )
+                )
+    if len(arrays) == 1:
+        return arrays[0]
+    return tuple(arrays)
 
 
 def blockdims_from_blockshape(shape, chunks):
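
For context, the rewritten paths above are driven through the public da.store API. A minimal sketch, with an in-memory NumPy array standing in for a real target such as a Zarr or HDF5 dataset:

import numpy as np
import dask.array as da

x = da.ones((10, 6), chunks=(2, 3))
target = np.empty(x.shape)

# compute=True, return_stored=False: writes eagerly via dask.compute
# and returns None (first branch above).
da.store([x], [target])

# return_stored=True: persists the store tasks, then returns arrays
# backed by load_chunk reads against the target (else branch above).
stored = da.store([x], [target], return_stored=True)
stored.sum().compute()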
@@ -1816,12 +1792,7 @@ def _elemwise(self):
 
     @wraps(store)
     def store(self, target, **kwargs):
-        r = store([self], [target], **kwargs)
-
-        if kwargs.get("return_stored", False):
-            r = r[0]
-
-        return r
+        return store([self], [target], **kwargs)
 
     def to_svg(self, size=500):
         """Convert chunks from Dask Array into an SVG Image
@@ -4567,11 +4538,12 @@ def concatenate(seq, axis=0, allow_unknown_chunksizes=False):
 def load_store_chunk(
     x: Any,
     out: Any,
-    index: slice,
+    index: slice | None,
+    region: slice | None,
     lock: Any,
     return_stored: bool,
     load_stored: bool,
-):
+) -> Any:
     """
     A function inserted in a Dask graph for storing a chunk.
 
@@ -4606,8 +4578,14 @@ def load_store_chunk(
 
     >>> a = np.ones((5, 6))
     >>> b = np.empty(a.shape)
-    >>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
+    >>> load_store_chunk(a, b, (slice(None), slice(None)), None, False, False, False)
     """
+    if region:
+        # Equivalent to `out[region][index]`
+        if index:
+            index = fuse_slice(region, index)
+        else:
+            index = region
     if lock:
         lock.acquire()
     try:
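
The region handling above leans on dask's fuse_slice helper, which composes two consecutive slicing operations into one, so that out[region][index] and out[fused] address the same elements. A small illustration (fuse_slice lives in dask.array.slicing):

import numpy as np
from dask.array.slicing import fuse_slice

x = np.arange(20)
region = slice(5, 15)
index = slice(2, 6)

fused = fuse_slice(region, index)
# Indexing through `fused` hits the same elements as x[region][index]
assert np.array_equal(x[region][index], x[fused])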
@@ -4628,119 +4606,19 @@ def load_store_chunk(
         lock.release()
 
 
-def store_chunk(
-    x: ArrayLike, out: ArrayLike, index: slice, lock: Any, return_stored: bool
-):
-    return load_store_chunk(x, out, index, lock, return_stored, False)
-
-
 A = TypeVar("A", bound=ArrayLike)
 
 
-def load_chunk(out: A, index: slice, lock: Any) -> A:
-    return load_store_chunk(None, out, index, lock, True, True)
-
-
-def insert_to_ooc(
-    keys: list,
-    chunks: tuple[tuple[int, ...], ...],
-    out: ArrayLike,
-    name: str,
-    *,
-    lock: Lock | bool = True,
-    region: tuple[slice, ...] | slice | None = None,
-    return_stored: bool = False,
-    load_stored: bool = False,
-) -> dict:
-    """
-    Creates a Dask graph for storing chunks from ``arr`` in ``out``.
-
-    Parameters
-    ----------
-    keys: list
-        Dask keys of the input array
-    chunks: tuple
-        Dask chunks of the input array
-    out: array-like
-        Where to store results to
-    name: str
-        First element of dask keys
-    lock: Lock-like or bool, optional
-        Whether to lock or with what (default is ``True``,
-        which means a :class:`threading.Lock` instance).
-    region: slice-like, optional
-        Where in ``out`` to store ``arr``'s results
-        (default is ``None``, meaning all of ``out``).
-    return_stored: bool, optional
-        Whether to return ``out``
-        (default is ``False``, meaning ``None`` is returned).
-    load_stored: bool, optional
-        Whether to handling loading from ``out`` at the same time.
-        Ignored if ``return_stored`` is not ``True``.
-        (default is ``False``, meaning defer to ``return_stored``).
-
-    Returns
-    -------
-    dask graph of store operation
-
-    Examples
-    --------
-    >>> import dask.array as da
-    >>> d = da.ones((5, 6), chunks=(2, 3))
-    >>> a = np.empty(d.shape)
-    >>> insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")  # doctest: +SKIP
-    """
-
-    if lock is True:
-        lock = Lock()
-
-    slices = slices_from_chunks(chunks)
-    if region:
-        slices = [fuse_slice(region, slc) for slc in slices]
-
-    if return_stored and load_stored:
-        func = load_store_chunk
-        args = (load_stored,)
-    else:
-        func = store_chunk  # type: ignore
-        args = ()  # type: ignore
-
-    dsk = {
-        (name,) + t[1:]: (func, t, out, slc, lock, return_stored) + args
-        for t, slc in zip(core.flatten(keys), slices)
-    }
-    return dsk
-
-
-def retrieve_from_ooc(
-    keys: Collection[Key], dsk_pre: Graph, dsk_post: Graph
-) -> dict[tuple, Any]:
-    """
-    Creates a Dask graph for loading stored ``keys`` from ``dsk``.
-
-    Parameters
-    ----------
-    keys: Collection
-        A sequence containing Dask graph keys to load
-    dsk_pre: Mapping
-        A Dask graph corresponding to a Dask Array before computation
-    dsk_post: Mapping
-        A Dask graph corresponding to a Dask Array after computation
-
-    Examples
-    --------
-    >>> import dask.array as da
-    >>> d = da.ones((5, 6), chunks=(2, 3))
-    >>> a = np.empty(d.shape)
-    >>> g = insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")
-    >>> retrieve_from_ooc(g.keys(), g, {k: k for k in g.keys()})  # doctest: +SKIP
-    """
-    load_dsk = {
-        ("load-" + k[0],) + k[1:]: (load_chunk, dsk_post[k]) + dsk_pre[k][3:-1]  # type: ignore
-        for k in keys
-    }
-
-    return load_dsk
+def load_chunk(out: A, index: slice, lock: Any, region: slice | None) -> A:
+    return load_store_chunk(
+        None,
+        out=out,
+        region=region,
+        index=index,
+        lock=lock,
+        return_stored=True,
+        load_stored=True,
+    )
 
 
 def _as_dtype(a, dtype):
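
With store_chunk, insert_to_ooc, and retrieve_from_ooc gone, the surviving pair splits the work: load_store_chunk writes a block into the target (optionally returning what was written), and load_chunk reads a region back out after the store tasks have been persisted. A simplified stand-in for that read path (read_one_block is a hypothetical name, not dask API):

import numpy as np

def read_one_block(out, index):
    # Hypothetical simplification of load_chunk: once the store has
    # run, each load task just slices the written region back out of
    # the target, so downstream tasks read from `out` instead of
    # recomputing the original source graph.
    return out[index]

target = np.arange(16).reshape(4, 4)
block = read_one_block(target, (slice(0, 2), slice(0, 2)))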
@@ -5546,7 +5424,7 @@ def concatenate3(arrays):
     NDARRAY_ARRAY_FUNCTION = getattr(np.ndarray, "__array_function__", None)
 
     arrays = concrete(arrays)
-    if not arrays:
+    if not arrays or all(el is None for el in flatten(arrays)):
         return np.empty(0)
 
     advanced = max(
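
The extra guard is presumably needed because the reworked store tasks return None per block when nothing is loaded back, so concatenate3 can now be handed a nested list of Nones; treating that like the empty case returns an empty array instead of raising. A tiny illustration of the new condition, using the flatten imported at the top of this diff:

from dask.core import flatten

blocks = [[None, None], [None, None]]
assert all(el is None for el in flatten(blocks))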
