Skip to content

Commit d6042e4

Browse files
authored
[backport][dask] Fix with the latest Dask. (dmlc#11291) (dmlc#11302)
1 parent 532318d commit d6042e4

File tree

3 files changed

+40
-46
lines changed

3 files changed

+40
-46
lines changed

python-package/xgboost/dask/__init__.py

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@
8383
from dask import array as da
8484
from dask import bag as db
8585
from dask import dataframe as dd
86+
from dask.delayed import Delayed
87+
from distributed import Future
8688

8789
from .. import collective, config
8890
from .._typing import FeatureNames, FeatureTypes, IterationRange
@@ -336,7 +338,7 @@ def __init__(
336338

337339
self._n_cols = data.shape[1]
338340
assert isinstance(self._n_cols, int)
339-
self.worker_map: Dict[str, List[distributed.Future]] = defaultdict(list)
341+
self.worker_map: Dict[str, List[Future]] = defaultdict(list)
340342
self.is_quantile: bool = False
341343

342344
self._init = client.sync(
@@ -369,7 +371,6 @@ async def _map_local_data(
369371
label_upper_bound: Optional[_DaskCollection] = None,
370372
) -> "DaskDMatrix":
371373
"""Obtain references to local data."""
372-
from dask.delayed import Delayed
373374

374375
def inconsistent(
375376
left: List[Any], left_name: str, right: List[Any], right_name: str
@@ -381,49 +382,39 @@ def inconsistent(
381382
)
382383
return msg
383384

384-
def check_columns(parts: numpy.ndarray) -> None:
385-
# x is required to be 2 dim in __init__
386-
assert parts.ndim == 1 or parts.shape[1], (
387-
"Data should be"
388-
" partitioned by row. To avoid this specify the number"
389-
" of columns for your dask Array explicitly. e.g."
390-
" chunks=(partition_size, X.shape[1])"
391-
)
392-
393-
def to_delayed(d: _DaskCollection) -> List[Delayed]:
394-
"""Breaking data into partitions, a trick borrowed from
395-
dask_xgboost. `to_delayed` downgrades high-level objects into numpy or
396-
pandas equivalents.
397-
398-
"""
385+
def to_futures(d: _DaskCollection) -> List[Future]:
386+
"""Breaking data into partitions."""
399387
d = client.persist(d)
400-
delayed_obj = d.to_delayed()
401-
if isinstance(delayed_obj, numpy.ndarray):
402-
# da.Array returns an array to delayed objects
403-
check_columns(delayed_obj)
404-
delayed_list: List[Delayed] = delayed_obj.flatten().tolist()
405-
else:
406-
# dd.DataFrame
407-
delayed_list = delayed_obj
408-
return delayed_list
388+
if (
389+
hasattr(d.partitions, "shape")
390+
and len(d.partitions.shape) > 1
391+
and d.partitions.shape[1] > 1
392+
):
393+
raise ValueError(
394+
"Data should be"
395+
" partitioned by row. To avoid this specify the number"
396+
" of columns for your dask Array explicitly. e.g."
397+
" chunks=(partition_size, -1)"
398+
)
399+
return client.futures_of(d)
409400

410-
def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Delayed]]:
401+
def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Future]]:
411402
if meta is not None:
412-
meta_parts: List[Delayed] = to_delayed(meta)
403+
meta_parts: List[Future] = to_futures(meta)
413404
return meta_parts
414405
return None
415406

416-
X_parts = to_delayed(data)
407+
X_parts = to_futures(data)
417408
y_parts = flatten_meta(label)
418409
w_parts = flatten_meta(weights)
419410
margin_parts = flatten_meta(base_margin)
420411
qid_parts = flatten_meta(qid)
421412
ll_parts = flatten_meta(label_lower_bound)
422413
lu_parts = flatten_meta(label_upper_bound)
423414

424-
parts: Dict[str, List[Delayed]] = {"data": X_parts}
415+
parts: Dict[str, List[Future]] = {"data": X_parts}
425416

426-
def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
417+
def append_meta(m_parts: Optional[List[Future]], name: str) -> None:
427418
if m_parts is not None:
428419
assert len(X_parts) == len(m_parts), inconsistent(
429420
X_parts, "X", m_parts, name
@@ -437,12 +428,12 @@ def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
437428
append_meta(ll_parts, "label_lower_bound")
438429
append_meta(lu_parts, "label_upper_bound")
439430
# At this point, `parts` looks like:
440-
# [(x0, x1, ..), (y0, y1, ..), ..] in delayed form
431+
# [(x0, x1, ..), (y0, y1, ..), ..] in future form
441432

442433
# turn into list of dictionaries.
443-
packed_parts: List[Dict[str, Delayed]] = []
434+
packed_parts: List[Dict[str, Future]] = []
444435
for i in range(len(X_parts)):
445-
part_dict: Dict[str, Delayed] = {}
436+
part_dict: Dict[str, Future] = {}
446437
for key, value in parts.items():
447438
part_dict[key] = value[i]
448439
packed_parts.append(part_dict)
@@ -451,16 +442,17 @@ def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
451442
# pylint: disable=no-member
452443
delayed_parts: List[Delayed] = list(map(dask.delayed, packed_parts))
453444
# At this point, the mental model should look like:
454-
# [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
445+
# [{"data": x0, "label": y0, ..}, {"data": x1, "label": y1, ..}, ..]
455446

456-
# convert delayed objects into futures and make sure they are realized
457-
fut_parts: List[distributed.Future] = client.compute(delayed_parts)
447+
# Convert delayed objects into futures and make sure they are realized
448+
#
449+
# This also makes partitions to align (co-locate) on workers (X_0, y_0 should be
450+
# on the same worker).
451+
fut_parts: List[Future] = client.compute(delayed_parts)
458452
await distributed.wait(fut_parts) # async wait for parts to be computed
459453

460-
# maybe we can call dask.align_partitions here to ease the partition alignment?
461-
462454
for part in fut_parts:
463-
# Each part is [x0, y0, w0, ...] in future form.
455+
# Each part is [{"data": x0, "label": y0, ..}, ...] in future form.
464456
assert part.status == "finished", part.status
465457

466458
# Preserving the partition order for prediction.
@@ -473,7 +465,7 @@ def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
473465
keys=[part.key for part in fut_parts]
474466
)
475467

476-
worker_map: Dict[str, List[distributed.Future]] = defaultdict(list)
468+
worker_map: Dict[str, List[Future]] = defaultdict(list)
477469

478470
for key, workers in who_has.items():
479471
worker_map[next(iter(workers))].append(key_to_partition[key])

src/c_api/c_api.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ void XGBBuildInfoDevice(Json *p_info) {
7676
#endif
7777

7878
XGB_DLL int XGBuildInfo(char const **out) {
79-
API_BEGIN();
79+
API_BEGIN_UNGUARD()
8080
xgboost_CHECK_C_ARG_PTR(out);
8181
Json info{Object{}};
8282

@@ -135,14 +135,14 @@ XGB_DLL int XGBuildInfo(char const **out) {
135135
}
136136

137137
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
138-
API_BEGIN_UNGUARD();
138+
API_BEGIN_UNGUARD()
139139
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
140140
registry->Register(callback);
141141
API_END();
142142
}
143143

144144
XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
145-
API_BEGIN();
145+
API_BEGIN_UNGUARD()
146146

147147
xgboost_CHECK_C_ARG_PTR(json_str);
148148
Json config{Json::Load(StringView{json_str})};
@@ -204,7 +204,7 @@ XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
204204
}
205205

206206
XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
207-
API_BEGIN();
207+
API_BEGIN_UNGUARD()
208208
auto const& global_config = *GlobalConfigThreadLocalStore::Get();
209209
Json config {ToJson(global_config)};
210210
auto const* mgr = global_config.__MANAGER__();
@@ -246,6 +246,8 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
246246
xgboost_CHECK_C_ARG_PTR(fname);
247247
xgboost_CHECK_C_ARG_PTR(out);
248248

249+
LOG(WARNING) << error::DeprecatedFunc(__func__, "2.0.0", "XGDMatrixCreateFromURI");
250+
249251
Json config{Object()};
250252
config["uri"] = std::string{fname};
251253
config["silent"] = silent;

tests/python-gpu/test_gpu_prediction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def predict_dense(x):
287287
def test_inplace_predict_cupy(self):
288288
self.run_inplace_predict_cupy(0)
289289

290-
@pytest.mark.skipif(**tm.no_cupy())
290+
@pytest.mark.skip
291291
@pytest.mark.mgpu
292292
def test_inplace_predict_cupy_specified_device(self):
293293
import cupy as cp

0 commit comments

Comments (0)