 from toolz import frequencies
 
 from dask._compatibility import import_optional_dependency
+from dask.core import flatten
 
 xr = import_optional_dependency("xarray", errors="ignore")
 
     persist,
     tokenize,
 )
+from dask.blockwise import BlockwiseDep
 from dask.blockwise import blockwise as core_blockwise
 from dask.blockwise import broadcast_dimensions
 from dask.context import globalmethod
 from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
 from dask.layers import ArrayBlockIdDep, ArraySliceDep, ArrayValuesDep
 from dask.sizeof import sizeof
-from dask.typing import Graph, Key, NestedKeys
+from dask.typing import Graph, NestedKeys
 from dask.utils import (
     IndexCallable,
     SerializableLock,
     derived_from,
     format_bytes,
     funcname,
+    get_scheduler_lock,
     has_keyword,
     is_arraylike,
     is_dataframe_like,
@@ -824,11 +827,15 @@ def map_blocks(
         new_axis = [new_axis]  # TODO: handle new_axis
 
     arrs = [a for a in args if isinstance(a, Array)]
+    argpairs = []
+    for a in args:
+        if isinstance(a, Array):
+            argpairs.append((a, tuple(range(a.ndim))[::-1]))
+        elif isinstance(a, BlockwiseDep):
+            argpairs.append((a, tuple(range(args[0].ndim))[::-1]))
+        else:
+            argpairs.append((a, None))
 
-    argpairs = [
-        (a, tuple(range(a.ndim))[::-1]) if isinstance(a, Array) else (a, None)
-        for a in args
-    ]
     if arrs:
         out_ind = tuple(range(max(a.ndim for a in arrs)))[::-1]
     else:
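With this change, `map_blocks` builds `argpairs` for `BlockwiseDep` arguments (such as `ArraySliceDep`) as well as for arrays, which is what lets the reworked `store` below pass per-block slices positionally. A minimal sketch of the resulting behavior, assuming a dask version containing this patch; `show_block` is a made-up helper:

```python
import dask.array as da
from dask.layers import ArraySliceDep

x = da.ones((4, 4), chunks=(2, 2))

def show_block(block, block_slices):
    # block_slices arrives as the tuple of slices selecting this block
    # from the full array, expanded per block by ArraySliceDep
    print(block.shape, block_slices)
    return block

# BlockwiseDep arguments must be positional, not keyword (see the note
# repeated in the store() hunk below)
y = x.map_blocks(show_block, ArraySliceDep(x.chunks), meta=x._meta)
y.compute()
```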
@@ -1189,85 +1196,54 @@ def store(
             )
     del regions
 
-    # Optimize all sources together
-    sources_hlg = HighLevelGraph.merge(*[e.__dask_graph__() for e in sources])
-    sources_layer = Array.__dask_optimize__(
-        sources_hlg, list(core.flatten([e.__dask_keys__() for e in sources]))
-    )
-    sources_name = "store-sources-" + tokenize(sources)
-    layers = {sources_name: sources_layer}
-    dependencies: dict[str, set[str]] = {sources_name: set()}
-
-    # Optimize all targets together
-    targets_keys = []
-    targets_dsks = []
-    for t in targets:
-        if isinstance(t, Delayed):
-            targets_keys.append(t.key)
-            targets_dsks.append(t.__dask_graph__())
-        elif is_dask_collection(t):
-            raise TypeError("Targets must be either Delayed objects or array-likes")
-
-    if targets_dsks:
-        targets_hlg = HighLevelGraph.merge(*targets_dsks)
-        targets_layer = Delayed.__dask_optimize__(targets_hlg, targets_keys)
-        targets_name = "store-targets-" + tokenize(targets_keys)
-        layers[targets_name] = targets_layer
-        dependencies[targets_name] = set()
-
     if load_stored is None:
         load_stored = return_stored and not compute
 
-    map_names = [
-        "store-map-" + tokenize(s, t if isinstance(t, Delayed) else id(t), r)
-        for s, t, r in zip(sources, targets, regions_list)
-    ]
-    map_keys: list[tuple] = []
-
-    for s, t, n, r in zip(sources, targets, map_names, regions_list):
-        map_layer = insert_to_ooc(
-            keys=s.__dask_keys__(),
-            chunks=s.chunks,
-            out=t.key if isinstance(t, Delayed) else t,
-            name=n,
-            lock=lock,
-            region=r,
-            return_stored=return_stored,
-            load_stored=load_stored,
-        )
-        layers[n] = map_layer
-        if isinstance(t, Delayed):
-            dependencies[n] = {sources_name, targets_name}
-        else:
-            dependencies[n] = {sources_name}
-        map_keys += map_layer.keys()
-
-    if return_stored:
-        store_dsk = HighLevelGraph(layers, dependencies)
-        load_store_dsk: HighLevelGraph | dict[tuple, Any] = store_dsk
-        if compute:
-            store_dlyds = [Delayed(k, store_dsk, layer=k[0]) for k in map_keys]
-            store_dlyds = persist(*store_dlyds, **kwargs)
-            store_dsk_2 = HighLevelGraph.merge(*[e.dask for e in store_dlyds])
-            load_store_dsk = retrieve_from_ooc(map_keys, store_dsk, store_dsk_2)
-            map_names = ["load-" + n for n in map_names]
-
-        return tuple(
-            Array(load_store_dsk, n, s.chunks, meta=s)
-            for s, n in zip(sources, map_names)
+    if lock is True:
+        lock = get_scheduler_lock(collection=Array, scheduler=kwargs.get("scheduler"))
+
+    arrays = []
+    for s, t, r in zip(sources, targets, regions_list):
+        slices = ArraySliceDep(s.chunks)
+        arrays.append(
+            s.map_blocks(
+                load_store_chunk,  # type: ignore[arg-type]
+                t,
+                # Note: slices / BlockwiseDep have to be passed by arg, not by kwarg
+                slices,
+                region=r,
+                lock=lock,
+                return_stored=return_stored,
+                load_stored=load_stored,
+                name="store-map",
+                meta=s._meta,
+            )
         )
 
-    elif compute:
-        store_dsk = HighLevelGraph(layers, dependencies)
-        compute_as_if_collection(Array, store_dsk, map_keys, **kwargs)
-        return None
+    if compute:
+        if not return_stored:
+            import dask
 
-    else:
-        key = "store-" + tokenize(map_names)
-        layers[key] = {key: map_keys}
-        dependencies[key] = set(map_names)
-        store_dsk = HighLevelGraph(layers, dependencies)
-        return Delayed(key, store_dsk)
+            dask.compute(arrays, **kwargs)
+            return None
+        else:
+            stored_persisted = persist(*arrays, **kwargs)
+            arrays = []
+            for s, r in zip(stored_persisted, regions_list):
+                slices = ArraySliceDep(s.chunks)
+                arrays.append(
+                    s.map_blocks(
+                        load_chunk,  # type: ignore[arg-type]
+                        # Note: slices / BlockwiseDep have to be passed by arg, not by kwarg
+                        slices,
+                        lock=lock,
+                        region=r,
+                        meta=s._meta,
+                    )
+                )
+    if len(arrays) == 1:
+        return arrays[0]
+    return tuple(arrays)
 
 
 def blockdims_from_blockshape(shape, chunks):
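For orientation, a hedged sketch of how the reworked `store` behaves from the caller's side, based on the branches above; a NumPy array stands in here for any `__setitem__`-supporting target (e.g. zarr or HDF5):

```python
import numpy as np
import dask.array as da

x = da.ones((4, 4), chunks=(2, 2))
target = np.zeros((4, 4))

# compute=True, return_stored=False (defaults): write chunks, return None
da.store([x], [target])

# compute=False: returns the lazy "store-map" array (a single source now
# yields a single object, not a Delayed); computing it performs the writes.
# Its chunks materialize as None, which is why concatenate3 below learns
# to tolerate all-None blocks.
lazy = da.store([x], [target], compute=False)
lazy.compute()

# return_stored=True: persist the writes, then get back an array that
# reads the stored data via load_chunk
stored = da.store([x], [target], return_stored=True)
```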
@@ -1816,12 +1792,7 @@ def _elemwise(self):
 
     @wraps(store)
     def store(self, target, **kwargs):
-        r = store([self], [target], **kwargs)
-
-        if kwargs.get("return_stored", False):
-            r = r[0]
-
-        return r
+        return store([self], [target], **kwargs)
 
     def to_svg(self, size=500):
         """Convert chunks from Dask Array into an SVG Image
@@ -4567,11 +4538,12 @@ def concatenate(seq, axis=0, allow_unknown_chunksizes=False):
 def load_store_chunk(
     x: Any,
     out: Any,
-    index: slice,
+    index: slice | None,
+    region: slice | None,
     lock: Any,
     return_stored: bool,
     load_stored: bool,
-):
+) -> Any:
     """
     A function inserted in a Dask graph for storing a chunk.
 
@@ -4606,8 +4578,14 @@ def load_store_chunk(
 
     >>> a = np.ones((5, 6))
     >>> b = np.empty(a.shape)
-    >>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
+    >>> load_store_chunk(a, b, (slice(None), slice(None)), None, False, False, False)
     """
+    if region:
+        # Equivalent to `out[region][index]`
+        if index:
+            index = fuse_slice(region, index)
+        else:
+            index = region
     if lock:
         lock.acquire()
     try:
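The `region`/`index` composition above relies on `fuse_slice`, which collapses two chained selections into a single one. A small illustration, assuming `dask.array.slicing.fuse_slice` (already used elsewhere in this file):

```python
import numpy as np
from dask.array.slicing import fuse_slice

out = np.arange(20)
region, index = slice(5, 15), slice(2, 6)

# One fused slice selects the same elements as out[region][index],
# so the chunk can be written without materializing out[region]
fused = fuse_slice(region, index)  # slice(7, 11, 1)
assert np.array_equal(out[fused], out[region][index])
```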
@@ -4628,119 +4606,19 @@ def load_store_chunk(
             lock.release()
 
 
-def store_chunk(
-    x: ArrayLike, out: ArrayLike, index: slice, lock: Any, return_stored: bool
-):
-    return load_store_chunk(x, out, index, lock, return_stored, False)
-
-
 A = TypeVar("A", bound=ArrayLike)
 
 
-def load_chunk(out: A, index: slice, lock: Any) -> A:
-    return load_store_chunk(None, out, index, lock, True, True)
-
-
-def insert_to_ooc(
-    keys: list,
-    chunks: tuple[tuple[int, ...], ...],
-    out: ArrayLike,
-    name: str,
-    *,
-    lock: Lock | bool = True,
-    region: tuple[slice, ...] | slice | None = None,
-    return_stored: bool = False,
-    load_stored: bool = False,
-) -> dict:
-    """
-    Creates a Dask graph for storing chunks from ``arr`` in ``out``.
-
-    Parameters
-    ----------
-    keys: list
-        Dask keys of the input array
-    chunks: tuple
-        Dask chunks of the input array
-    out: array-like
-        Where to store results to
-    name: str
-        First element of dask keys
-    lock: Lock-like or bool, optional
-        Whether to lock or with what (default is ``True``,
-        which means a :class:`threading.Lock` instance).
-    region: slice-like, optional
-        Where in ``out`` to store ``arr``'s results
-        (default is ``None``, meaning all of ``out``).
-    return_stored: bool, optional
-        Whether to return ``out``
-        (default is ``False``, meaning ``None`` is returned).
-    load_stored: bool, optional
-        Whether to handling loading from ``out`` at the same time.
-        Ignored if ``return_stored`` is not ``True``.
-        (default is ``False``, meaning defer to ``return_stored``).
-
-    Returns
-    -------
-    dask graph of store operation
-
-    Examples
-    --------
-    >>> import dask.array as da
-    >>> d = da.ones((5, 6), chunks=(2, 3))
-    >>> a = np.empty(d.shape)
-    >>> insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")  # doctest: +SKIP
-    """
-
-    if lock is True:
-        lock = Lock()
-
-    slices = slices_from_chunks(chunks)
-    if region:
-        slices = [fuse_slice(region, slc) for slc in slices]
-
-    if return_stored and load_stored:
-        func = load_store_chunk
-        args = (load_stored,)
-    else:
-        func = store_chunk  # type: ignore
-        args = ()  # type: ignore
-
-    dsk = {
-        (name,) + t[1:]: (func, t, out, slc, lock, return_stored) + args
-        for t, slc in zip(core.flatten(keys), slices)
-    }
-    return dsk
-
-
-def retrieve_from_ooc(
-    keys: Collection[Key], dsk_pre: Graph, dsk_post: Graph
-) -> dict[tuple, Any]:
-    """
-    Creates a Dask graph for loading stored ``keys`` from ``dsk``.
-
-    Parameters
-    ----------
-    keys: Collection
-        A sequence containing Dask graph keys to load
-    dsk_pre: Mapping
-        A Dask graph corresponding to a Dask Array before computation
-    dsk_post: Mapping
-        A Dask graph corresponding to a Dask Array after computation
-
-    Examples
-    --------
-    >>> import dask.array as da
-    >>> d = da.ones((5, 6), chunks=(2, 3))
-    >>> a = np.empty(d.shape)
-    >>> g = insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")
-    >>> retrieve_from_ooc(g.keys(), g, {k: k for k in g.keys()})  # doctest: +SKIP
-    """
-    load_dsk = {
-        ("load-" + k[0],) + k[1:]: (load_chunk, dsk_post[k]) + dsk_pre[k][3:-1]  # type: ignore
-        for k in keys
-    }
-
-    return load_dsk
+def load_chunk(out: A, index: slice, lock: Any, region: slice | None) -> A:
+    return load_store_chunk(
+        None,
+        out=out,
+        region=region,
+        index=index,
+        lock=lock,
+        return_stored=True,
+        load_stored=True,
+    )
 
 
 def _as_dtype(a, dtype):
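A small walk-through of the consolidated `load_store_chunk` contract, sketched from the hunks above; it calls the private helper directly, purely for illustration:

```python
import numpy as np
from dask.array.core import load_store_chunk

out = np.zeros(4)

# region places the source array inside `out`; index places this chunk
# inside the region; the two slices are fused into one write to out[2:4]
load_store_chunk(
    np.ones(2), out, index=slice(0, 2), region=slice(2, 4),
    lock=False, return_stored=False, load_stored=False,
)
assert out.tolist() == [0.0, 0.0, 1.0, 1.0]
```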
@@ -5546,7 +5424,7 @@ def concatenate3(arrays):
     NDARRAY_ARRAY_FUNCTION = getattr(np.ndarray, "__array_function__", None)
 
     arrays = concrete(arrays)
-    if not arrays:
+    if not arrays or all(el is None for el in flatten(arrays)):
         return np.empty(0)
 
     advanced = max(
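This guard is what lets a computed store-map array, whose chunks are all None placeholders, finalize cleanly. A hedged check of the new behavior:

```python
from dask.array.core import concatenate3

# All-None blocks (what a computed store-map array produces) now
# collapse to an empty array instead of failing during assembly
assert concatenate3([[None, None], [None, None]]).size == 0
```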