Skip to content

Commit 344f90b

Browse files
wbo4958trivialfis
andauthored
[jvm-packages] throw exception when tree_method=approx and device=cuda (dmlc#9478)
--------- Co-authored-by: Jiaming Yuan <[email protected]>
1 parent 05d7000 commit 344f90b

File tree

3 files changed

+47
-28
lines changed

3 files changed

+47
-28
lines changed

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
9393

9494
private val overridedParams = overrideParams(rawParams, sc)
9595

96+
validateSparkSslConf()
97+
9698
/**
9799
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
98100
* If so, throw an exception unless this safety measure has been explicitly overridden
99101
* via conf `xgboost.spark.ignoreSsl`.
100102
*/
101-
private def validateSparkSslConf: Unit = {
103+
private def validateSparkSslConf(): Unit = {
102104
val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
103105
SparkSession.getActiveSession match {
104106
case Some(ss) =>
@@ -148,55 +150,59 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
148150
overridedParams
149151
}
150152

153+
/**
154+
* The Map parameters accepted by estimator's constructor may have string type,
155+
* Eg, Map("num_workers" -> "6", "num_round" -> 5), we need to convert these
156+
* kind of parameters into the correct type in the function.
157+
*
158+
* @return XGBoostExecutionParams
159+
*/
151160
def buildXGBRuntimeParams: XGBoostExecutionParams = {
161+
162+
val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
163+
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
164+
if (obj != null) {
165+
require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
166+
"is not defined, you have to specify the objective type as classification or regression" +
167+
" with a customized objective function")
168+
}
169+
170+
var trainTestRatio = 1.0
171+
if (overridedParams.contains("train_test_ratio")) {
172+
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
173+
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
174+
"'eval_set_names'")
175+
trainTestRatio = overridedParams.get("train_test_ratio").get.asInstanceOf[Double]
176+
}
177+
152178
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
153179
val round = overridedParams("num_round").asInstanceOf[Int]
154180
val useExternalMemory = overridedParams
155181
.getOrElse("use_external_memory", false).asInstanceOf[Boolean]
156-
val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
157-
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
182+
158183
val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
159184
val allowNonZeroForMissing = overridedParams
160185
.getOrElse("allow_non_zero_for_missing", false)
161186
.asInstanceOf[Boolean]
162-
validateSparkSslConf
163-
var treeMethod: Option[String] = None
164-
if (overridedParams.contains("tree_method")) {
165-
require(overridedParams("tree_method") == "hist" ||
166-
overridedParams("tree_method") == "approx" ||
167-
overridedParams("tree_method") == "auto" ||
168-
overridedParams("tree_method") == "gpu_hist", "xgboost4j-spark only supports tree_method" +
169-
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
170-
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
171-
}
172187

188+
val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString)
173189
// back-compatible with "gpu_hist"
174190
val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) {
175191
Some("cuda")
176192
} else overridedParams.get("device").map(_.toString)
177193

178-
if (overridedParams.contains("train_test_ratio")) {
179-
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
180-
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
181-
"'eval_set_names'")
182-
}
183-
require(nWorkers > 0, "you must specify more than 0 workers")
184-
if (obj != null) {
185-
require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
186-
"is not defined, you have to specify the objective type as classification or regression" +
187-
" with a customized objective function")
188-
}
194+
require(!(treeMethod.exists(_ == "approx") && device.exists(_ == "cuda")),
195+
"The tree method \"approx\" is not yet supported for Spark GPU cluster")
196+
189197
val trackerConf = overridedParams.get("tracker_conf") match {
190198
case None => TrackerConf()
191199
case Some(conf: TrackerConf) => conf
192200
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
193201
"instance of TrackerConf.")
194202
}
195-
val checkpointParam =
196-
ExternalCheckpointParams.extractParams(overridedParams)
197203

198-
val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
199-
.asInstanceOf[Double]
204+
val checkpointParam = ExternalCheckpointParams.extractParams(overridedParams)
205+
200206
val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
201207
val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed)
202208

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,13 @@ private[spark] trait LearningTaskParams extends Params {
6868
/**
6969
* Fraction of training points to use for testing.
7070
*/
71+
@Deprecated
7172
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
7273
"fraction of training points to use for testing",
7374
ParamValidators.inRange(0, 1))
7475
setDefault(trainTestRatio, 1.0)
7576

77+
@Deprecated
7678
final def getTrainTestRatio: Double = $(trainTestRatio)
7779

7880
/**

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,15 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
9292
classifier.getBaseScore
9393
}
9494
}
95+
96+
test("approx can't be used for gpu train") {
97+
val paramMap = Map("tree_method" -> "approx", "device" -> "cuda")
98+
val trainingDF = buildDataFrame(MultiClassification.train)
99+
val xgb = new XGBoostClassifier(paramMap)
100+
val thrown = intercept[IllegalArgumentException] {
101+
xgb.fit(trainingDF)
102+
}
103+
assert(thrown.getMessage.contains("The tree method \"approx\" is not yet supported " +
104+
"for Spark GPU cluster"))
105+
}
95106
}

0 commit comments

Comments
 (0)