 from dask import array as da
 from dask import bag as db
 from dask import dataframe as dd
+from dask.delayed import Delayed
+from distributed import Future
 
 from .. import collective, config
 from .._typing import FeatureNames, FeatureTypes, IterationRange
@@ -336,7 +338,7 @@ def __init__(
 
         self._n_cols = data.shape[1]
         assert isinstance(self._n_cols, int)
-        self.worker_map: Dict[str, List[distributed.Future]] = defaultdict(list)
+        self.worker_map: Dict[str, List[Future]] = defaultdict(list)
         self.is_quantile: bool = False
 
         self._init = client.sync(
@@ -369,7 +371,6 @@ async def _map_local_data(
         label_upper_bound: Optional[_DaskCollection] = None,
     ) -> "DaskDMatrix":
         """Obtain references to local data."""
-        from dask.delayed import Delayed
 
         def inconsistent(
             left: List[Any], left_name: str, right: List[Any], right_name: str
@@ -381,49 +382,39 @@ def inconsistent(
             )
             return msg
 
-        def check_columns(parts: numpy.ndarray) -> None:
-            # x is required to be 2 dim in __init__
-            assert parts.ndim == 1 or parts.shape[1], (
-                "Data should be"
-                " partitioned by row. To avoid this specify the number"
-                " of columns for your dask Array explicitly. e.g."
-                " chunks=(partition_size, X.shape[1])"
-            )
-
-        def to_delayed(d: _DaskCollection) -> List[Delayed]:
-            """Breaking data into partitions, a trick borrowed from
-            dask_xgboost. `to_delayed` downgrades high-level objects into numpy or
-            pandas equivalents.
-
-            """
+        def to_futures(d: _DaskCollection) -> List[Future]:
+            """Break the data into row partitions."""
             d = client.persist(d)
-            delayed_obj = d.to_delayed()
-            if isinstance(delayed_obj, numpy.ndarray):
-                # da.Array returns an array of delayed objects
-                check_columns(delayed_obj)
-                delayed_list: List[Delayed] = delayed_obj.flatten().tolist()
-            else:
-                # dd.DataFrame
-                delayed_list = delayed_obj
-            return delayed_list
+            if (
+                hasattr(d.partitions, "shape")
+                and len(d.partitions.shape) > 1
+                and d.partitions.shape[1] > 1
+            ):
+                raise ValueError(
+                    "Data should be"
+                    " partitioned by row. To avoid this, specify the number"
+                    " of columns for your dask Array explicitly, e.g."
+                    " chunks=(partition_size, -1)"
+                )
+            return client.futures_of(d)
 
-        def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Delayed]]:
+        def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Future]]:
             if meta is not None:
-                meta_parts: List[Delayed] = to_delayed(meta)
+                meta_parts: List[Future] = to_futures(meta)
                 return meta_parts
             return None
 
-        X_parts = to_delayed(data)
+        X_parts = to_futures(data)
         y_parts = flatten_meta(label)
         w_parts = flatten_meta(weights)
         margin_parts = flatten_meta(base_margin)
         qid_parts = flatten_meta(qid)
         ll_parts = flatten_meta(label_lower_bound)
         lu_parts = flatten_meta(label_upper_bound)
 
-        parts: Dict[str, List[Delayed]] = {"data": X_parts}
+        parts: Dict[str, List[Future]] = {"data": X_parts}
 
-        def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
+        def append_meta(m_parts: Optional[List[Future]], name: str) -> None:
             if m_parts is not None:
                 assert len(X_parts) == len(m_parts), inconsistent(
                     X_parts, "X", m_parts, name
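Note on the new `to_futures` path above: after `client.persist`, each surviving task is one partition, `d.partitions.shape` exposes the chunk grid, and `futures_of` hands back one future per partition, so column-wise chunking is caught before any data moves. A minimal sketch, assuming an in-process `distributed.Client` and illustrative random data; it uses the module-level `distributed.futures_of`, the functional twin of the `client.futures_of` call in the diff:

```python
import dask.array as da
from distributed import Client, futures_of

client = Client(processes=False)  # illustrative local cluster

# Chunked by row only: the block grid is (10, 1), one future per row partition.
X = client.persist(da.random.random((1_000, 10), chunks=(100, -1)))
assert X.partitions.shape == (10, 1)
parts = futures_of(X)  # List[Future], one per partition
assert len(parts) == 10

# Chunked by column as well: partitions.shape[1] > 1, which the check in
# to_futures rejects with a ValueError.
X_bad = da.random.random((1_000, 10), chunks=(100, 5))
assert X_bad.partitions.shape == (10, 2)
```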
@@ -437,12 +428,12 @@ def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
         append_meta(ll_parts, "label_lower_bound")
         append_meta(lu_parts, "label_upper_bound")
         # At this point, `parts` looks like:
-        # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form
+        # [(x0, x1, ..), (y0, y1, ..), ..] in future form
 
         # turn into list of dictionaries.
-        packed_parts: List[Dict[str, Delayed]] = []
+        packed_parts: List[Dict[str, Future]] = []
         for i in range(len(X_parts)):
-            part_dict: Dict[str, Delayed] = {}
+            part_dict: Dict[str, Future] = {}
             for key, value in parts.items():
                 part_dict[key] = value[i]
             packed_parts.append(part_dict)
@@ -451,16 +442,17 @@ def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
         # pylint: disable=no-member
         delayed_parts: List[Delayed] = list(map(dask.delayed, packed_parts))
         # At this point, the mental model should look like:
-        # [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
+        # [{"data": x0, "label": y0, ..}, {"data": x1, "label": y1, ..}, ..]
 
-        # convert delayed objects into futures and make sure they are realized
-        fut_parts: List[distributed.Future] = client.compute(delayed_parts)
+        # Convert delayed objects into futures and make sure they are realized.
+        #
+        # This also co-locates the partitions on the workers (X_0 and y_0 should
+        # end up on the same worker).
+        fut_parts: List[Future] = client.compute(delayed_parts)
         await distributed.wait(fut_parts)  # async wait for parts to be computed
 
-        # maybe we can call dask.align_partitions here to ease the partition alignment?
-
         for part in fut_parts:
-            # Each part is [x0, y0, w0, ...] in future form.
+            # Each part is a {"data": x0, "label": y0, ..} dict in future form.
             assert part.status == "finished", part.status
 
         # Preserving the partition order for prediction.
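The pack-and-realize step in this hunk is the heart of the change: wrapping each dict of per-partition futures in `dask.delayed` turns the whole dict into a single task that depends on those futures, so computing it materializes `X_i` and `y_i` together on one worker. A rough sketch of the same pattern, reusing `client` and `X` from the sketch above with an illustrative label vector `y` (index-wise alignment of `X_parts` and `y_parts` is assumed, as in the diff):

```python
import dask
import dask.array as da
from distributed import futures_of, wait

y = client.persist(da.random.random(1_000, chunks=100))
X_parts, y_parts = futures_of(X), futures_of(y)

packed_parts = [{"data": xp, "label": yp} for xp, yp in zip(X_parts, y_parts)]
delayed_parts = [dask.delayed(p) for p in packed_parts]
fut_parts = client.compute(delayed_parts)  # one Future per packed partition
wait(fut_parts)  # blocking twin of the `await distributed.wait(...)` above
assert all(part.status == "finished" for part in fut_parts)
```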
@@ -473,7 +465,7 @@ def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
             keys=[part.key for part in fut_parts]
         )
 
-        worker_map: Dict[str, List[distributed.Future]] = defaultdict(list)
+        worker_map: Dict[str, List[Future]] = defaultdict(list)
 
         for key, workers in who_has.items():
             worker_map[next(iter(workers))].append(key_to_partition[key])
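Finally, the worker map: the scheduler reports which worker holds each realized partition, and each partition is assigned to exactly one holder so it is loaded into the DMatrix only once. The same mapping can be reproduced with the public `Client.who_has` in place of the raw scheduler RPC used in the diff; a sketch, continuing from the previous one:

```python
from collections import defaultdict
from typing import Dict, List
from distributed import Future

who_has = client.who_has(fut_parts)  # key -> tuple of worker addresses
key_to_partition = {part.key: part for part in fut_parts}

worker_map: Dict[str, List[Future]] = defaultdict(list)
for key, workers in who_has.items():
    # A finished future may be replicated on several workers; keep only the
    # first holder so each partition is consumed exactly once.
    worker_map[next(iter(workers))].append(key_to_partition[key])

for addr, futs in worker_map.items():
    print(addr, len(futs))
```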