@@ -28,6 +28,7 @@
 from toolz import frequencies
 
 from dask._compatibility import import_optional_dependency
+from dask.core import flatten
 
 xr = import_optional_dependency("xarray", errors="ignore")
 
@@ -60,6 +61,7 @@
     persist,
     tokenize,
 )
+from dask.blockwise import BlockwiseDep
 from dask.blockwise import blockwise as core_blockwise
 from dask.blockwise import broadcast_dimensions
 from dask.context import globalmethod
@@ -68,7 +70,7 @@
 from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
 from dask.layers import ArrayBlockIdDep, ArraySliceDep, ArrayValuesDep
 from dask.sizeof import sizeof
-from dask.typing import Graph, Key, NestedKeys
+from dask.typing import Graph, NestedKeys
 from dask.utils import (
     IndexCallable,
     SerializableLock,
@@ -79,6 +81,7 @@
     derived_from,
     format_bytes,
     funcname,
+    get_scheduler_lock,
     has_keyword,
     is_arraylike,
     is_dataframe_like,
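
The newly imported `get_scheduler_lock` chooses a lock implementation that matches the scheduler that will execute the graph; `store` below calls it when `lock=True`. A minimal sketch, assuming `dask.utils.get_scheduler_lock` with its `collection=`/`scheduler=` keywords (the scheduler name here is illustrative):

    from dask.utils import get_scheduler_lock
    import dask.array as da

    # For the threaded scheduler this typically yields a SerializableLock;
    # for the processes scheduler, a multiprocessing manager lock.
    lock = get_scheduler_lock(collection=da.Array, scheduler="threads")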
@@ -824,11 +827,15 @@ def map_blocks(
         new_axis = [new_axis]  # TODO: handle new_axis
 
     arrs = [a for a in args if isinstance(a, Array)]
+    argpairs = []
+    for a in args:
+        if isinstance(a, Array):
+            argpairs.append((a, tuple(range(a.ndim))[::-1]))
+        elif isinstance(a, BlockwiseDep):
+            argpairs.append((a, tuple(range(args[0].ndim))[::-1]))
+        else:
+            argpairs.append((a, None))
 
-    argpairs = [
-        (a, tuple(range(a.ndim))[::-1]) if isinstance(a, Array) else (a, None)
-        for a in args
-    ]
     if arrs:
         out_ind = tuple(range(max(a.ndim for a in arrs)))[::-1]
     else:
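
The rewritten `argpairs` loop is what lets `map_blocks` accept a `BlockwiseDep` (such as `ArraySliceDep`) as a positional argument, pairing it with the block index of the first array argument; the new `store` below relies on this. A hedged sketch of the pattern (shapes and the `show_block` helper are illustrative):

    import dask.array as da
    from dask.layers import ArraySliceDep

    x = da.ones((4, 6), chunks=(2, 3))

    def show_block(block, slc):
        # ``slc`` arrives as the tuple of slices locating this block in the
        # full array, e.g. (slice(0, 2), slice(3, 6)).
        return block

    y = x.map_blocks(show_block, ArraySliceDep(x.chunks), dtype=x.dtype)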
@@ -1189,85 +1196,54 @@ def store(
         )
         del regions
 
-    # Optimize all sources together
-    sources_hlg = HighLevelGraph.merge(*[e.__dask_graph__() for e in sources])
-    sources_layer = Array.__dask_optimize__(
-        sources_hlg, list(core.flatten([e.__dask_keys__() for e in sources]))
-    )
-    sources_name = "store-sources-" + tokenize(sources)
-    layers = {sources_name: sources_layer}
-    dependencies: dict[str, set[str]] = {sources_name: set()}
-
-    # Optimize all targets together
-    targets_keys = []
-    targets_dsks = []
-    for t in targets:
-        if isinstance(t, Delayed):
-            targets_keys.append(t.key)
-            targets_dsks.append(t.__dask_graph__())
-        elif is_dask_collection(t):
-            raise TypeError("Targets must be either Delayed objects or array-likes")
-
-    if targets_dsks:
-        targets_hlg = HighLevelGraph.merge(*targets_dsks)
-        targets_layer = Delayed.__dask_optimize__(targets_hlg, targets_keys)
-        targets_name = "store-targets-" + tokenize(targets_keys)
-        layers[targets_name] = targets_layer
-        dependencies[targets_name] = set()
-
     if load_stored is None:
         load_stored = return_stored and not compute
 
-    map_names = [
-        "store-map-" + tokenize(s, t if isinstance(t, Delayed) else id(t), r)
-        for s, t, r in zip(sources, targets, regions_list)
-    ]
-    map_keys: list[tuple] = []
-
-    for s, t, n, r in zip(sources, targets, map_names, regions_list):
-        map_layer = insert_to_ooc(
-            keys=s.__dask_keys__(),
-            chunks=s.chunks,
-            out=t.key if isinstance(t, Delayed) else t,
-            name=n,
-            lock=lock,
-            region=r,
-            return_stored=return_stored,
-            load_stored=load_stored,
-        )
-        layers[n] = map_layer
-        if isinstance(t, Delayed):
-            dependencies[n] = {sources_name, targets_name}
-        else:
-            dependencies[n] = {sources_name}
-        map_keys += map_layer.keys()
-
-    if return_stored:
-        store_dsk = HighLevelGraph(layers, dependencies)
-        load_store_dsk: HighLevelGraph | dict[tuple, Any] = store_dsk
-        if compute:
-            store_dlyds = [Delayed(k, store_dsk, layer=k[0]) for k in map_keys]
-            store_dlyds = persist(*store_dlyds, **kwargs)
-            store_dsk_2 = HighLevelGraph.merge(*[e.dask for e in store_dlyds])
-            load_store_dsk = retrieve_from_ooc(map_keys, store_dsk, store_dsk_2)
-            map_names = ["load-" + n for n in map_names]
-
-        return tuple(
-            Array(load_store_dsk, n, s.chunks, meta=s)
-            for s, n in zip(sources, map_names)
+    if lock is True:
+        lock = get_scheduler_lock(collection=Array, scheduler=kwargs.get("scheduler"))
+
+    arrays = []
+    for s, t, r in zip(sources, targets, regions_list):
+        slices = ArraySliceDep(s.chunks)
+        arrays.append(
+            s.map_blocks(
+                load_store_chunk,  # type: ignore[arg-type]
+                t,
+                # Note: slices / BlockwiseDep have to be passed by arg, not by kwarg
+                slices,
+                region=r,
+                lock=lock,
+                return_stored=return_stored,
+                load_stored=load_stored,
+                name="store-map",
+                meta=s._meta,
+            )
         )
 
-    elif compute:
-        store_dsk = HighLevelGraph(layers, dependencies)
-        compute_as_if_collection(Array, store_dsk, map_keys, **kwargs)
-        return None
+    if compute:
+        if not return_stored:
+            import dask
 
-    else:
-        key = "store-" + tokenize(map_names)
-        layers[key] = {key: map_keys}
-        dependencies[key] = set(map_names)
-        store_dsk = HighLevelGraph(layers, dependencies)
-        return Delayed(key, store_dsk)
+            dask.compute(arrays, **kwargs)
+            return None
+        else:
+            stored_persisted = persist(*arrays, **kwargs)
+            arrays = []
+            for s, r in zip(stored_persisted, regions_list):
+                slices = ArraySliceDep(s.chunks)
+                arrays.append(
+                    s.map_blocks(
+                        load_chunk,  # type: ignore[arg-type]
+                        # Note: slices / BlockwiseDep have to be passed by arg, not by kwarg
+                        slices,
+                        lock=lock,
+                        region=r,
+                        meta=s._meta,
+                    )
                )
+    if len(arrays) == 1:
+        return arrays[0]
+    return tuple(arrays)
 
 
 def blockdims_from_blockshape(shape, chunks):
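
Each source/target pair is now a single blockwise "store-map" array, so the compute/return combinations reduce to the branches above. A hedged usage sketch of the resulting behavior (shapes are illustrative):

    import numpy as np
    import dask.array as da

    d = da.ones((5, 6), chunks=(2, 3))
    out = np.empty(d.shape)

    da.store([d], [out])                        # compute=True: stores eagerly, returns None
    lazy = da.store([d], [out], compute=False)  # lazy store-map collection, via the fall-through return
    loaded = da.store([d], [out], return_stored=True)  # persists, then wraps targets in load-chunk arrays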
@@ -1816,12 +1792,7 @@ def _elemwise(self):
 
     @wraps(store)
     def store(self, target, **kwargs):
-        r = store([self], [target], **kwargs)
-
-        if kwargs.get("return_stored", False):
-            r = r[0]
-
-        return r
+        return store([self], [target], **kwargs)
 
     def to_svg(self, size=500):
         """Convert chunks from Dask Array into an SVG Image
@@ -4567,11 +4538,12 @@ def concatenate(seq, axis=0, allow_unknown_chunksizes=False):
 def load_store_chunk(
     x: Any,
     out: Any,
-    index: slice,
+    index: slice | None,
+    region: slice | None,
     lock: Any,
     return_stored: bool,
     load_stored: bool,
-):
+) -> Any:
     """
     A function inserted in a Dask graph for storing a chunk.
 
@@ -4606,8 +4578,14 @@ def load_store_chunk(
 
     >>> a = np.ones((5, 6))
     >>> b = np.empty(a.shape)
-    >>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
+    >>> load_store_chunk(a, b, (slice(None), slice(None)), None, False, False, False)
     """
+    if region:
+        # Equivalent to `out[region][index]`
+        if index:
+            index = fuse_slice(region, index)
+        else:
+            index = region
     if lock:
         lock.acquire()
     try:
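
The region/index composition leans on `fuse_slice`, which builds a single slice such that `out[fused]` selects the same elements as `out[region][index]`. A worked illustration, assuming the usual `dask.array.slicing.fuse_slice`:

    from dask.array.slicing import fuse_slice

    region = slice(2, 8)
    index = slice(1, 3)
    fused = fuse_slice(region, index)
    # out[2:8][1:3] picks elements 3 and 4, so ``fused`` is equivalent to slice(3, 5)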
@@ -4628,119 +4606,19 @@ def load_store_chunk(
             lock.release()
 
 
-def store_chunk(
-    x: ArrayLike, out: ArrayLike, index: slice, lock: Any, return_stored: bool
-):
-    return load_store_chunk(x, out, index, lock, return_stored, False)
-
-
 A = TypeVar("A", bound=ArrayLike)
 
 
-def load_chunk(out: A, index: slice, lock: Any) -> A:
-    return load_store_chunk(None, out, index, lock, True, True)
-
-
-def insert_to_ooc(
-    keys: list,
-    chunks: tuple[tuple[int, ...], ...],
-    out: ArrayLike,
-    name: str,
-    *,
-    lock: Lock | bool = True,
-    region: tuple[slice, ...] | slice | None = None,
-    return_stored: bool = False,
-    load_stored: bool = False,
-) -> dict:
-    """
-    Creates a Dask graph for storing chunks from ``arr`` in ``out``.
-
-    Parameters
-    ----------
-    keys: list
-        Dask keys of the input array
-    chunks: tuple
-        Dask chunks of the input array
-    out: array-like
-        Where to store results to
-    name: str
-        First element of dask keys
-    lock: Lock-like or bool, optional
-        Whether to lock or with what (default is ``True``,
-        which means a :class:`threading.Lock` instance).
-    region: slice-like, optional
-        Where in ``out`` to store ``arr``'s results
-        (default is ``None``, meaning all of ``out``).
-    return_stored: bool, optional
-        Whether to return ``out``
-        (default is ``False``, meaning ``None`` is returned).
-    load_stored: bool, optional
-        Whether to handling loading from ``out`` at the same time.
-        Ignored if ``return_stored`` is not ``True``.
-        (default is ``False``, meaning defer to ``return_stored``).
-
-    Returns
-    -------
-    dask graph of store operation
-
-    Examples
-    --------
-    >>> import dask.array as da
-    >>> d = da.ones((5, 6), chunks=(2, 3))
-    >>> a = np.empty(d.shape)
-    >>> insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")  # doctest: +SKIP
-    """
-
-    if lock is True:
-        lock = Lock()
-
-    slices = slices_from_chunks(chunks)
-    if region:
-        slices = [fuse_slice(region, slc) for slc in slices]
-
-    if return_stored and load_stored:
-        func = load_store_chunk
-        args = (load_stored,)
-    else:
-        func = store_chunk  # type: ignore
-        args = ()  # type: ignore
-
-    dsk = {
-        (name,) + t[1:]: (func, t, out, slc, lock, return_stored) + args
-        for t, slc in zip(core.flatten(keys), slices)
-    }
-    return dsk
-
-
-def retrieve_from_ooc(
-    keys: Collection[Key], dsk_pre: Graph, dsk_post: Graph
-) -> dict[tuple, Any]:
-    """
-    Creates a Dask graph for loading stored ``keys`` from ``dsk``.
-
-    Parameters
-    ----------
-    keys: Collection
-        A sequence containing Dask graph keys to load
-    dsk_pre: Mapping
-        A Dask graph corresponding to a Dask Array before computation
-    dsk_post: Mapping
-        A Dask graph corresponding to a Dask Array after computation
-
-    Examples
-    --------
-    >>> import dask.array as da
-    >>> d = da.ones((5, 6), chunks=(2, 3))
-    >>> a = np.empty(d.shape)
-    >>> g = insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")
-    >>> retrieve_from_ooc(g.keys(), g, {k: k for k in g.keys()})  # doctest: +SKIP
-    """
-    load_dsk = {
-        ("load-" + k[0],) + k[1:]: (load_chunk, dsk_post[k]) + dsk_pre[k][3:-1]  # type: ignore
-        for k in keys
-    }
-
-    return load_dsk
+def load_chunk(out: A, index: slice, lock: Any, region: slice | None) -> A:
+    return load_store_chunk(
+        None,
+        out=out,
+        region=region,
+        index=index,
+        lock=lock,
+        return_stored=True,
+        load_stored=True,
+    )
 
 
 def _as_dtype(a, dtype):
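
The surviving `load_chunk` is a keyword-forwarding wrapper used by the second `map_blocks` pass in `store`: with `x=None` and `load_stored=True`, `load_store_chunk` only reads `out[index]` back. A hedged sketch of the equivalent call (values illustrative):

    import numpy as np

    out = np.arange(30).reshape(5, 6)
    block = load_chunk(out, index=(slice(0, 2), slice(0, 3)), lock=False, region=None)
    # block equals out[0:2, 0:3]; a non-None region would be fused into the index first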
@@ -5546,7 +5424,7 @@ def concatenate3(arrays):
     NDARRAY_ARRAY_FUNCTION = getattr(np.ndarray, "__array_function__", None)
 
     arrays = concrete(arrays)
-    if not arrays:
+    if not arrays or all(el is None for el in flatten(arrays)):
         return np.empty(0)
 
     advanced = max(
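
The extra guard matters because store-map chunks return `None` when `return_stored=False`, and those results can reach `concatenate3` as arbitrarily nested lists. A hedged illustration of the condition:

    from dask.core import flatten

    blocks = [[None, None], [None, None]]       # nested all-None results from store-map chunks
    all(el is None for el in flatten(blocks))   # True, so concatenate3 returns np.empty(0)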