|
55 | 55 | import logging
|
56 | 56 | from collections import defaultdict
|
57 | 57 | from contextlib import contextmanager
|
58 |
| -from functools import partial, update_wrapper, wraps |
| 58 | +from functools import cache, partial, update_wrapper |
59 | 59 | from threading import Thread
|
60 | 60 | from typing import (
|
61 | 61 | Any,
|
|
85 | 85 | from dask import dataframe as dd
|
86 | 86 | from dask.delayed import Delayed
|
87 | 87 | from distributed import Future
|
| 88 | +from packaging.version import Version |
| 89 | +from packaging.version import parse as parse_version |
88 | 90 |
|
89 | 91 | from .. import collective, config
|
90 | 92 | from .._typing import FeatureNames, FeatureTypes, IterationRange
|
|
171 | 173 | LOGGER = logging.getLogger("[xgboost.dask]")
|
172 | 174 |
|
173 | 175 |
|
| 176 | +@cache |
| 177 | +def _DASK_VERSION() -> Version: |
| 178 | + return parse_version(dask.__version__) |
| 179 | + |
| 180 | + |
| 181 | +@cache |
| 182 | +def _DASK_2024_12_1() -> bool: |
| 183 | + return _DASK_VERSION() >= parse_version("2024.12.1") |
| 184 | + |
| 185 | + |
| 186 | +@cache |
| 187 | +def _DASK_2025_3_0() -> bool: |
| 188 | + return _DASK_VERSION() >= parse_version("2025.3.0") |
| 189 | + |
| 190 | + |
174 | 191 | def _try_start_tracker(
|
175 | 192 | n_workers: int,
|
176 | 193 | addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
|
@@ -1476,35 +1493,33 @@ async def _predict_async(
|
1476 | 1493 | iteration_range: Optional[IterationRange],
|
1477 | 1494 | ) -> Any:
|
1478 | 1495 | iteration_range = self._get_iteration_range(iteration_range)
|
1479 |
| - if self._can_use_inplace_predict(): |
1480 |
| - predts = await inplace_predict( |
1481 |
| - client=self.client, |
1482 |
| - model=self.get_booster(), |
1483 |
| - data=data, |
1484 |
| - iteration_range=iteration_range, |
1485 |
| - predict_type="margin" if output_margin else "value", |
1486 |
| - missing=self.missing, |
1487 |
| - base_margin=base_margin, |
1488 |
| - validate_features=validate_features, |
1489 |
| - ) |
1490 |
| - if isinstance(predts, dd.DataFrame): |
1491 |
| - predts = predts.to_dask_array() |
1492 |
| - else: |
1493 |
| - test_dmatrix: DaskDMatrix = await DaskDMatrix( # type: ignore |
1494 |
| - self.client, |
1495 |
| - data=data, |
1496 |
| - base_margin=base_margin, |
1497 |
| - missing=self.missing, |
1498 |
| - feature_types=self.feature_types, |
1499 |
| - ) |
1500 |
| - predts = await predict( |
1501 |
| - self.client, |
1502 |
| - model=self.get_booster(), |
1503 |
| - data=test_dmatrix, |
1504 |
| - output_margin=output_margin, |
1505 |
| - validate_features=validate_features, |
1506 |
| - iteration_range=iteration_range, |
1507 |
| - ) |
| 1496 | + # Dask doesn't support gblinear and accepts only Dask collection types (array |
| 1497 | + # and dataframe). We can perform inplace predict. |
| 1498 | + assert self._can_use_inplace_predict() |
| 1499 | + predts = await inplace_predict( |
| 1500 | + client=self.client, |
| 1501 | + model=self.get_booster(), |
| 1502 | + data=data, |
| 1503 | + iteration_range=iteration_range, |
| 1504 | + predict_type="margin" if output_margin else "value", |
| 1505 | + missing=self.missing, |
| 1506 | + base_margin=base_margin, |
| 1507 | + validate_features=validate_features, |
| 1508 | + ) |
| 1509 | + if isinstance(predts, dd.DataFrame): |
| 1510 | + predts = predts.to_dask_array() |
| 1511 | + # Make sure the booster is part of the task graph implicitly |
| 1512 | + # only needed for certain versions of dask. |
| 1513 | + if _DASK_2024_12_1() and not _DASK_2025_3_0(): |
| 1514 | + # Fixes this issue for dask>=2024.1.1,<2025.3.0 |
| 1515 | + # Dask==2025.3.0 fails with: |
| 1516 | + # RuntimeError: Attempting to use an asynchronous |
| 1517 | + # Client in a synchronous context of `dask.compute` |
| 1518 | + # |
| 1519 | + # Dask==2025.4.0 fails with: |
| 1520 | + # TypeError: Value type is not supported for data |
| 1521 | + # iterator:<class 'distributed.client.Future'> |
| 1522 | + predts = predts.persist() |
1508 | 1523 | return predts
|
1509 | 1524 |
|
1510 | 1525 | @_deprecate_positional_args
|
|
0 commit comments