 from dask import array as da
 from dask import bag as db
 from dask import dataframe as dd
+from dask.delayed import Delayed
+from distributed import Future

 from .. import collective, config
 from .._typing import FeatureNames, FeatureTypes, IterationRange
@@ -336,7 +338,7 @@ def __init__(

         self._n_cols = data.shape[1]
         assert isinstance(self._n_cols, int)
-        self.worker_map: Dict[str, List[distributed.Future]] = defaultdict(list)
+        self.worker_map: Dict[str, List[Future]] = defaultdict(list)
         self.is_quantile: bool = False

         self._init = client.sync(
@@ -369,7 +371,6 @@ async def _map_local_data(
         label_upper_bound: Optional[_DaskCollection] = None,
     ) -> "DaskDMatrix":
         """Obtain references to local data."""
-        from dask.delayed import Delayed

         def inconsistent(
             left: List[Any], left_name: str, right: List[Any], right_name: str
@@ -381,49 +382,39 @@ def inconsistent(
             )
             return msg

-        def check_columns(parts: numpy.ndarray) -> None:
-            # x is required to be 2 dim in __init__
-            assert parts.ndim == 1 or parts.shape[1], (
-                "Data should be"
-                " partitioned by row. To avoid this specify the number"
-                " of columns for your dask Array explicitly. e.g."
-                " chunks=(partition_size, X.shape[1])"
-            )
-
-        def to_delayed(d: _DaskCollection) -> List[Delayed]:
-            """Breaking data into partitions, a trick borrowed from
-            dask_xgboost. `to_delayed` downgrades high-level objects into numpy or
-            pandas equivalents.
-
-            """
+        def to_futures(d: _DaskCollection) -> List[Future]:
+            """Breaking data into partitions."""
             d = client.persist(d)
-            delayed_obj = d.to_delayed()
-            if isinstance(delayed_obj, numpy.ndarray):
-                # da.Array returns an array to delayed objects
-                check_columns(delayed_obj)
-                delayed_list: List[Delayed] = delayed_obj.flatten().tolist()
-            else:
-                # dd.DataFrame
-                delayed_list = delayed_obj
-            return delayed_list
+            if (
+                hasattr(d.partitions, "shape")
+                and len(d.partitions.shape) > 1
+                and d.partitions.shape[1] > 1
+            ):
+                raise ValueError(
+                    "Data should be"
+                    " partitioned by row. To avoid this specify the number"
+                    " of columns for your dask Array explicitly. e.g."
+                    " chunks=(partition_size, -1)"
+                )
+            return client.futures_of(d)

-        def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Delayed]]:
+        def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Future]]:
             if meta is not None:
-                meta_parts: List[Delayed] = to_delayed(meta)
+                meta_parts: List[Future] = to_futures(meta)
                 return meta_parts
             return None

-        X_parts = to_delayed(data)
+        X_parts = to_futures(data)
         y_parts = flatten_meta(label)
         w_parts = flatten_meta(weights)
         margin_parts = flatten_meta(base_margin)
         qid_parts = flatten_meta(qid)
         ll_parts = flatten_meta(label_lower_bound)
         lu_parts = flatten_meta(label_upper_bound)

-        parts: Dict[str, List[Delayed]] = {"data": X_parts}
+        parts: Dict[str, List[Future]] = {"data": X_parts}

-        def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
+        def append_meta(m_parts: Optional[List[Future]], name: str) -> None:
             if m_parts is not None:
                 assert len(X_parts) == len(m_parts), inconsistent(
                     X_parts, "X", m_parts, name
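The new `to_futures` helper replaces the old `to_delayed` round trip: the collection is persisted and its per-partition futures are taken directly, with a row-partitioning check on the chunk grid. A minimal standalone sketch of the same pattern, assuming a running `distributed` client; the name `partition_futures` and the random test data are illustrative only, not part of the patch:

    from typing import List

    from dask import array as da
    from distributed import Client, Future

    def partition_futures(client: Client, d) -> List[Future]:
        # Persist the collection so that every chunk is materialized on a worker.
        d = client.persist(d)
        # Require row-wise partitioning: a chunk grid wider than one column would
        # split the features of a single row across partitions.
        if (
            hasattr(d.partitions, "shape")
            and len(d.partitions.shape) > 1
            and d.partitions.shape[1] > 1
        ):
            raise ValueError("Data must be partitioned by row, e.g. chunks=(rows, -1).")
        # One Future per chunk of the persisted collection.
        return client.futures_of(d)

    client = Client()  # local cluster, for illustration only
    X = da.random.random((1000, 16), chunks=(250, -1))
    parts = partition_futures(client, X)  # four futures, one per row block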
@@ -437,12 +428,12 @@ def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
         append_meta(ll_parts, "label_lower_bound")
         append_meta(lu_parts, "label_upper_bound")
         # At this point, `parts` looks like:
-        # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form
+        # [(x0, x1, ..), (y0, y1, ..), ..] in future form

         # turn into list of dictionaries.
-        packed_parts: List[Dict[str, Delayed]] = []
+        packed_parts: List[Dict[str, Future]] = []
         for i in range(len(X_parts)):
-            part_dict: Dict[str, Delayed] = {}
+            part_dict: Dict[str, Future] = {}
             for key, value in parts.items():
                 part_dict[key] = value[i]
             packed_parts.append(part_dict)
@@ -451,16 +442,17 @@ def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
         # pylint: disable=no-member
         delayed_parts: List[Delayed] = list(map(dask.delayed, packed_parts))
         # At this point, the mental model should look like:
-        # [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
+        # [{"data": x0, "label": y0, ..}, {"data": x1, "label": y1, ..}, ..]

-        # convert delayed objects into futures and make sure they are realized
-        fut_parts: List[distributed.Future] = client.compute(delayed_parts)
+        # Convert delayed objects into futures and make sure they are realized
+        #
+        # This also aligns (co-locates) the partitions on workers (X_0 and y_0
+        # should land on the same worker).
+        fut_parts: List[Future] = client.compute(delayed_parts)
         await distributed.wait(fut_parts)  # async wait for parts to be computed

-        # maybe we can call dask.align_partitions here to ease the partition alignment?
-
         for part in fut_parts:
-            # Each part is [x0, y0, w0, ...] in future form.
+            # Each part is [{"data": x0, "label": y0, ..}, ...] in future form.
             assert part.status == "finished", part.status

         # Preserving the partition order for prediction.
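The co-location trick described in the comment above (pack the i-th chunk of every input into one dict, wrap it with `dask.delayed`, and compute each dict as a single task) can be sketched in isolation. The scattered numpy chunks below merely stand in for real partitions; all names are illustrative, not part of the patch:

    import dask
    import numpy as np
    from distributed import Client, wait

    client = Client()  # local cluster, for illustration only
    # Hypothetical per-partition chunks, scattered to the cluster as futures.
    X_parts = client.scatter([np.ones((10, 4)), np.zeros((10, 4))])
    y_parts = client.scatter([np.ones(10), np.zeros(10)])

    # Each dict becomes one task, so the scheduler gathers its inputs onto a
    # single worker: x_i and y_i end up co-located.
    packed = [{"data": x, "label": y} for x, y in zip(X_parts, y_parts)]
    fut_parts = client.compute(list(map(dask.delayed, packed)))
    wait(fut_parts)
    assert all(part.status == "finished" for part in fut_parts)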
@@ -473,7 +465,7 @@ def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
             keys=[part.key for part in fut_parts]
         )

-        worker_map: Dict[str, List[distributed.Future]] = defaultdict(list)
+        worker_map: Dict[str, List[Future]] = defaultdict(list)

         for key, workers in who_has.items():
             worker_map[next(iter(workers))].append(key_to_partition[key])
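Finally, a small sketch of the grouping step that builds `worker_map` from the realized partition futures, here using the synchronous `Client.who_has` rather than the scheduler comm used in the method; `group_by_worker` is an illustrative name, not part of the patch:

    from collections import defaultdict
    from typing import Dict, List

    from distributed import Client, Future

    def group_by_worker(client: Client, fut_parts: List[Future]) -> Dict[str, List[Future]]:
        key_to_partition = {part.key: part for part in fut_parts}
        # Maps each future's key to the worker addresses currently holding its data.
        who_has = client.who_has(fut_parts)
        worker_map: Dict[str, List[Future]] = defaultdict(list)
        for key, workers in who_has.items():
            # Assign every partition to the first worker that reports holding it.
            worker_map[next(iter(workers))].append(key_to_partition[key])
        return worker_map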