File tree Expand file tree Collapse file tree 2 files changed +7
-10
lines changed
Expand file tree Collapse file tree 2 files changed +7
-10
lines changed Original file line number Diff line number Diff line change @@ -254,15 +254,11 @@ def __init__(self, **args: CollArgsVals) -> None:
254254 super ().__init__ (** args )
255255
256256 worker = distributed .get_worker ()
257- with distributed .worker_client () as client :
258- info = client .scheduler_info ()
259- w = info ["workers" ][worker .address ]
260- wid = w ["id" ]
261257 # We use task ID for rank assignment which makes the RABIT rank consistent (but
262258 # not the same as task ID is string and "10" is sorted before "2") with dask
263- # worker ID . This outsources the rank assignment to dask and prevents
259+ # worker name . This outsources the rank assignment to dask and prevents
264260 # non-deterministic issue.
265- self .args ["DMLC_TASK_ID" ] = f"[xgboost.dask-{ wid } ]:" + str ( worker .address )
261+ self .args ["DMLC_TASK_ID" ] = f"[xgboost.dask-{ worker . name } ]:{ worker .address } "
266262
267263
268264def _get_client (client : Optional ["distributed.Client" ]) -> "distributed.Client" :
@@ -923,12 +919,11 @@ def train( # pylint: disable=unused-argument
923919
924920 """
925921 client = _get_client (client )
926- args = locals ()
927922 return client .sync (
928923 _train_async ,
929924 global_config = config .get_config (),
930925 dconfig = _get_dask_config (),
931- ** args ,
926+ ** locals () ,
932927 )
933928
934929
Original file line number Diff line number Diff line change 77from dask import array as da
88from dask import dataframe as dd
99from distributed import Client , get_worker
10+ from packaging .version import parse as parse_version
1011from sklearn .datasets import make_classification
1112
1213import xgboost as xgb
1516from xgboost .testing .updater import get_basescore
1617
1718from .. import dask as dxgb
18- from ..dask import _get_rabit_args
19+ from ..dask import _DASK_VERSION , _get_rabit_args
1920
2021
2122def check_init_estimation_clf (
@@ -177,7 +178,8 @@ def get_rabit_args(client: Client, n_workers: int) -> Any:
177178
178179def get_client_workers (client : Client ) -> List [str ]:
179180 "Get workers from a dask client."
180- workers = client .scheduler_info ()["workers" ]
181+ kwargs = {"n_workers" : - 1 } if _DASK_VERSION () >= parse_version ("2025.4.0" ) else {}
182+ workers = client .scheduler_info (** kwargs )["workers" ]
181183 return list (workers .keys ())
182184
183185
You can’t perform that action at this time.
0 commit comments