|
55 | 55 | import logging |
56 | 56 | from collections import defaultdict |
57 | 57 | from contextlib import contextmanager |
58 | | -from functools import partial, update_wrapper, wraps |
| 58 | +from functools import cache, partial, update_wrapper |
59 | 59 | from threading import Thread |
60 | 60 | from typing import ( |
61 | 61 | Any, |
|
85 | 85 | from dask import dataframe as dd |
86 | 86 | from dask.delayed import Delayed |
87 | 87 | from distributed import Future |
| 88 | +from packaging.version import Version |
| 89 | +from packaging.version import parse as parse_version |
88 | 90 |
|
89 | 91 | from .. import collective, config |
90 | 92 | from .._typing import FeatureNames, FeatureTypes, IterationRange |
|
171 | 173 | LOGGER = logging.getLogger("[xgboost.dask]") |
172 | 174 |
|
173 | 175 |
|
@cache
def _DASK_VERSION() -> Version:
    """Return the installed Dask version as a parsed ``Version`` object.

    The result is memoized with ``functools.cache``, so ``dask.__version__`` is
    parsed only on the first call.
    """
    installed = dask.__version__
    return parse_version(installed)
| 179 | + |
| 180 | + |
@cache
def _DASK_2024_12_1() -> bool:
    """Report whether the installed Dask is version 2024.12.1 or newer.

    The answer is memoized with ``functools.cache``.
    """
    installed = _DASK_VERSION()
    minimum = parse_version("2024.12.1")
    return installed >= minimum
| 184 | + |
| 185 | + |
@cache
def _DASK_2025_3_0() -> bool:
    """Report whether the installed Dask is version 2025.3.0 or newer.

    The answer is memoized with ``functools.cache``.
    """
    installed = _DASK_VERSION()
    minimum = parse_version("2025.3.0")
    return installed >= minimum
| 189 | + |
| 190 | + |
174 | 191 | def _try_start_tracker( |
175 | 192 | n_workers: int, |
176 | 193 | addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]], |
@@ -1476,35 +1493,33 @@ async def _predict_async( |
1476 | 1493 | iteration_range: Optional[IterationRange], |
1477 | 1494 | ) -> Any: |
1478 | 1495 | iteration_range = self._get_iteration_range(iteration_range) |
1479 | | - if self._can_use_inplace_predict(): |
1480 | | - predts = await inplace_predict( |
1481 | | - client=self.client, |
1482 | | - model=self.get_booster(), |
1483 | | - data=data, |
1484 | | - iteration_range=iteration_range, |
1485 | | - predict_type="margin" if output_margin else "value", |
1486 | | - missing=self.missing, |
1487 | | - base_margin=base_margin, |
1488 | | - validate_features=validate_features, |
1489 | | - ) |
1490 | | - if isinstance(predts, dd.DataFrame): |
1491 | | - predts = predts.to_dask_array() |
1492 | | - else: |
1493 | | - test_dmatrix: DaskDMatrix = await DaskDMatrix( # type: ignore |
1494 | | - self.client, |
1495 | | - data=data, |
1496 | | - base_margin=base_margin, |
1497 | | - missing=self.missing, |
1498 | | - feature_types=self.feature_types, |
1499 | | - ) |
1500 | | - predts = await predict( |
1501 | | - self.client, |
1502 | | - model=self.get_booster(), |
1503 | | - data=test_dmatrix, |
1504 | | - output_margin=output_margin, |
1505 | | - validate_features=validate_features, |
1506 | | - iteration_range=iteration_range, |
1507 | | - ) |
| 1496 | + # Dask doesn't support gblinear and accepts only Dask collection types (array |
| 1497 | + # and dataframe). We can perform inplace predict. |
| 1498 | + assert self._can_use_inplace_predict() |
| 1499 | + predts = await inplace_predict( |
| 1500 | + client=self.client, |
| 1501 | + model=self.get_booster(), |
| 1502 | + data=data, |
| 1503 | + iteration_range=iteration_range, |
| 1504 | + predict_type="margin" if output_margin else "value", |
| 1505 | + missing=self.missing, |
| 1506 | + base_margin=base_margin, |
| 1507 | + validate_features=validate_features, |
| 1508 | + ) |
| 1509 | + if isinstance(predts, dd.DataFrame): |
| 1510 | + predts = predts.to_dask_array() |
| 1511 | + # Make sure the booster is part of the task graph implicitly |
| 1512 | + # only needed for certain versions of dask. |
| 1513 | + if _DASK_2024_12_1() and not _DASK_2025_3_0(): |
| 1514 | + # Fixes this issue for dask>=2024.12.1,<2025.3.0 |
| 1515 | + # Dask==2025.3.0 fails with: |
| 1516 | + # RuntimeError: Attempting to use an asynchronous |
| 1517 | + # Client in a synchronous context of `dask.compute` |
| 1518 | + # |
| 1519 | + # Dask==2025.4.0 fails with: |
| 1520 | + # TypeError: Value type is not supported for data |
| 1521 | + # iterator:<class 'distributed.client.Future'> |
| 1522 | + predts = predts.persist() |
1508 | 1523 | return predts |
1509 | 1524 |
|
1510 | 1525 | @_deprecate_positional_args |
|
0 commit comments