@@ -424,35 +424,41 @@ def _validate_params(self) -> None:
424
424
425
425
if is_local :
426
426
# checking spark local mode.
427
- if gpu_per_task :
427
+ if gpu_per_task is not None :
428
428
raise RuntimeError (
429
- "The spark cluster does not support gpu configuration for local mode. "
430
- "Please delete spark.executor.resource.gpu.amount and "
429
+ "The spark local mode does not support gpu configuration. "
430
+ "Please remove spark.executor.resource.gpu.amount and "
431
431
"spark.task.resource.gpu.amount"
432
432
)
433
433
434
- # Support GPU training in Spark local mode is just for debugging purposes,
435
- # so it's okay for printing the below warning instead of checking the real
436
- # gpu numbers and raising the exception.
434
+ # Support GPU training in Spark local mode is just for debugging
435
+ # purposes, so it's okay for printing the below warning instead of
436
+ # checking the real gpu numbers and raising the exception.
437
437
get_logger (self .__class__ .__name__ ).warning (
438
- "You enabled GPU in spark local mode. Please make sure your local "
439
- "node has at least %d GPUs" ,
438
+ "You have enabled GPU in spark local mode. Please make sure your"
439
+ " local node has at least %d GPUs" ,
440
440
self .getOrDefault (self .num_workers ),
441
441
)
442
442
else :
443
443
# checking spark non-local mode.
444
- if not gpu_per_task or int (gpu_per_task ) < 1 :
445
- raise RuntimeError (
446
- "The spark cluster does not have the necessary GPU"
447
- + "configuration for the spark task. Therefore, we cannot"
448
- + "run xgboost training using GPU."
449
- )
450
-
451
- if int (gpu_per_task ) > 1 :
452
- get_logger (self .__class__ .__name__ ).warning (
453
- "You configured %s GPU cores for each spark task, but in "
454
- "XGBoost training, every Spark task will only use one GPU core." ,
455
- gpu_per_task ,
444
+ if gpu_per_task is not None :
445
+ if float (gpu_per_task ) < 1.0 :
446
+ raise ValueError (
447
+ "XGBoost doesn't support GPU fractional configurations. "
448
+ "Please set `spark.task.resource.gpu.amount=spark.executor"
449
+ ".resource.gpu.amount`"
450
+ )
451
+
452
+ if float (gpu_per_task ) > 1.0 :
453
+ get_logger (self .__class__ .__name__ ).warning (
454
+ "%s GPUs for each Spark task is configured, but each "
455
+ "XGBoost training task uses only 1 GPU." ,
456
+ gpu_per_task ,
457
+ )
458
+ else :
459
+ raise ValueError (
460
+ "The `spark.task.resource.gpu.amount` is required for training"
461
+ " on GPU."
456
462
)
457
463
458
464
0 commit comments