Skip to content

Commit 3fd8451

Browse files
authored
[backport][dask] Disable broadcast in the scatter call. (dmlc#10632) (dmlc#10634)
1 parent 7643306 commit 3fd8451

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

python-package/xgboost/dask/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,12 +1237,14 @@ def _infer_predict_output(
12371237
async def _get_model_future(
12381238
client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
12391239
) -> "distributed.Future":
1240-
# See https://github.com/dask/dask/issues/11179#issuecomment-2168094529 for
1241-
# the use of hash.
1240+
# See https://github.com/dask/dask/issues/11179#issuecomment-2168094529 for the use
1241+
# of hash.
1242+
# https://github.com/dask/distributed/pull/8796 Don't use broadcast in the `scatter`
1243+
# call, otherwise, the predict function might hang.
12421244
if isinstance(model, Booster):
1243-
booster = await client.scatter(model, broadcast=True, hash=False)
1245+
booster = await client.scatter(model, hash=False)
12441246
elif isinstance(model, dict):
1245-
booster = await client.scatter(model["booster"], broadcast=True, hash=False)
1247+
booster = await client.scatter(model["booster"], hash=False)
12461248
elif isinstance(model, distributed.Future):
12471249
booster = model
12481250
t = booster.type

0 commit comments

Comments
 (0)