Skip to content

Commit 68be454

Browse files
wbo4958trivialfis
andauthored
[pyspark] hotfix for GPU setup validation (dmlc#9495)
* [pyspark] fix a bug of validating gpu configuration --------- Co-authored-by: Jiaming Yuan <[email protected]>
1 parent 5188e27 commit 68be454

File tree

1 file changed

+26
-20
lines changed
  • python-package/xgboost/spark

1 file changed

+26
-20
lines changed

python-package/xgboost/spark/core.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -424,35 +424,41 @@ def _validate_params(self) -> None:
424424

425425
if is_local:
426426
# checking spark local mode.
427-
if gpu_per_task:
427+
if gpu_per_task is not None:
428428
raise RuntimeError(
429-
"The spark cluster does not support gpu configuration for local mode. "
430-
"Please delete spark.executor.resource.gpu.amount and "
429+
"The spark local mode does not support gpu configuration."
430+
"Please remove spark.executor.resource.gpu.amount and "
431431
"spark.task.resource.gpu.amount"
432432
)
433433

434-
# Support GPU training in Spark local mode is just for debugging purposes,
435-
# so it's okay for printing the below warning instead of checking the real
436-
# gpu numbers and raising the exception.
434+
# Support GPU training in Spark local mode is just for debugging
435+
# purposes, so it's okay for printing the below warning instead of
436+
# checking the real gpu numbers and raising the exception.
437437
get_logger(self.__class__.__name__).warning(
438-
"You enabled GPU in spark local mode. Please make sure your local "
439-
"node has at least %d GPUs",
438+
"You have enabled GPU in spark local mode. Please make sure your"
439+
" local node has at least %d GPUs",
440440
self.getOrDefault(self.num_workers),
441441
)
442442
else:
443443
# checking spark non-local mode.
444-
if not gpu_per_task or int(gpu_per_task) < 1:
445-
raise RuntimeError(
446-
"The spark cluster does not have the necessary GPU"
447-
+ "configuration for the spark task. Therefore, we cannot"
448-
+ "run xgboost training using GPU."
449-
)
450-
451-
if int(gpu_per_task) > 1:
452-
get_logger(self.__class__.__name__).warning(
453-
"You configured %s GPU cores for each spark task, but in "
454-
"XGBoost training, every Spark task will only use one GPU core.",
455-
gpu_per_task,
444+
if gpu_per_task is not None:
445+
if float(gpu_per_task) < 1.0:
446+
raise ValueError(
447+
"XGBoost doesn't support GPU fractional configurations. "
448+
"Please set `spark.task.resource.gpu.amount=spark.executor"
449+
".resource.gpu.amount`"
450+
)
451+
452+
if float(gpu_per_task) > 1.0:
453+
get_logger(self.__class__.__name__).warning(
454+
"%s GPUs for each Spark task is configured, but each "
455+
"XGBoost training task uses only 1 GPU.",
456+
gpu_per_task,
457+
)
458+
else:
459+
raise ValueError(
460+
"The `spark.task.resource.gpu.amount` is required for training"
461+
" on GPU."
456462
)
457463

458464

0 commit comments

Comments
 (0)