Skip to content

Commit 5cdd05e

Browse files
[backport][dask] Workarounds for different Dask versions. (dmlc#11436) (dmlc#11437)
--------- Co-authored-by: TomAugspurger <[email protected]>
1 parent 10f5f6d commit 5cdd05e

File tree

1 file changed

+45
-30
lines changed

1 file changed

+45
-30
lines changed

python-package/xgboost/dask/__init__.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
import logging
5656
from collections import defaultdict
5757
from contextlib import contextmanager
58-
from functools import partial, update_wrapper, wraps
58+
from functools import cache, partial, update_wrapper
5959
from threading import Thread
6060
from typing import (
6161
Any,
@@ -85,6 +85,8 @@
8585
from dask import dataframe as dd
8686
from dask.delayed import Delayed
8787
from distributed import Future
88+
from packaging.version import Version
89+
from packaging.version import parse as parse_version
8890

8991
from .. import collective, config
9092
from .._typing import FeatureNames, FeatureTypes, IterationRange
@@ -171,6 +173,21 @@
171173
LOGGER = logging.getLogger("[xgboost.dask]")
172174

173175

176+
@cache
177+
def _DASK_VERSION() -> Version:
178+
return parse_version(dask.__version__)
179+
180+
181+
@cache
182+
def _DASK_2024_12_1() -> bool:
183+
return _DASK_VERSION() >= parse_version("2024.12.1")
184+
185+
186+
@cache
187+
def _DASK_2025_3_0() -> bool:
188+
return _DASK_VERSION() >= parse_version("2025.3.0")
189+
190+
174191
def _try_start_tracker(
175192
n_workers: int,
176193
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
@@ -1476,35 +1493,33 @@ async def _predict_async(
14761493
iteration_range: Optional[IterationRange],
14771494
) -> Any:
14781495
iteration_range = self._get_iteration_range(iteration_range)
1479-
if self._can_use_inplace_predict():
1480-
predts = await inplace_predict(
1481-
client=self.client,
1482-
model=self.get_booster(),
1483-
data=data,
1484-
iteration_range=iteration_range,
1485-
predict_type="margin" if output_margin else "value",
1486-
missing=self.missing,
1487-
base_margin=base_margin,
1488-
validate_features=validate_features,
1489-
)
1490-
if isinstance(predts, dd.DataFrame):
1491-
predts = predts.to_dask_array()
1492-
else:
1493-
test_dmatrix: DaskDMatrix = await DaskDMatrix( # type: ignore
1494-
self.client,
1495-
data=data,
1496-
base_margin=base_margin,
1497-
missing=self.missing,
1498-
feature_types=self.feature_types,
1499-
)
1500-
predts = await predict(
1501-
self.client,
1502-
model=self.get_booster(),
1503-
data=test_dmatrix,
1504-
output_margin=output_margin,
1505-
validate_features=validate_features,
1506-
iteration_range=iteration_range,
1507-
)
1496+
# Dask doesn't support gblinear and accepts only Dask collection types (array
1497+
# and dataframe). We can perform inplace predict.
1498+
assert self._can_use_inplace_predict()
1499+
predts = await inplace_predict(
1500+
client=self.client,
1501+
model=self.get_booster(),
1502+
data=data,
1503+
iteration_range=iteration_range,
1504+
predict_type="margin" if output_margin else "value",
1505+
missing=self.missing,
1506+
base_margin=base_margin,
1507+
validate_features=validate_features,
1508+
)
1509+
if isinstance(predts, dd.DataFrame):
1510+
predts = predts.to_dask_array()
1511+
# Make sure the booster is part of the task graph implicitly
1512+
# only needed for certain versions of dask.
1513+
if _DASK_2024_12_1() and not _DASK_2025_3_0():
1514+
# Fixes this issue for dask>=2024.1.1,<2025.3.0
1515+
# Dask==2025.3.0 fails with:
1516+
# RuntimeError: Attempting to use an asynchronous
1517+
# Client in a synchronous context of `dask.compute`
1518+
#
1519+
# Dask==2025.4.0 fails with:
1520+
# TypeError: Value type is not supported for data
1521+
# iterator:<class 'distributed.client.Future'>
1522+
predts = predts.persist()
15081523
return predts
15091524

15101525
@_deprecate_positional_args

0 commit comments

Comments
 (0)