@@ -93,12 +93,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
93
93
94
94
private val overridedParams = overrideParams(rawParams, sc)
95
95
96
+ validateSparkSslConf()
97
+
96
98
/**
97
99
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
98
100
* If so, throw an exception unless this safety measure has been explicitly overridden
99
101
* via conf `xgboost.spark.ignoreSsl`.
100
102
*/
101
- private def validateSparkSslConf : Unit = {
103
+ private def validateSparkSslConf () : Unit = {
102
104
val (sparkSslEnabled : Boolean , xgboostSparkIgnoreSsl : Boolean ) =
103
105
SparkSession .getActiveSession match {
104
106
case Some (ss) =>
@@ -148,55 +150,59 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
148
150
overridedParams
149
151
}
150
152
153
+ /**
154
+ * The Map parameters accepted by the estimator's constructor may have string values,
155
+ * e.g. Map("num_workers" -> "6", "num_round" -> 5), so we need to convert these
156
+ * kinds of parameters into the correct types in this function.
157
+ *
158
+ * @return XGBoostExecutionParams
159
+ */
151
160
def buildXGBRuntimeParams : XGBoostExecutionParams = {
161
+
162
+ val obj = overridedParams.getOrElse(" custom_obj" , null ).asInstanceOf [ObjectiveTrait ]
163
+ val eval = overridedParams.getOrElse(" custom_eval" , null ).asInstanceOf [EvalTrait ]
164
+ if (obj != null ) {
165
+ require(overridedParams.get(" objective_type" ).isDefined, " parameter \" objective_type\" " +
166
+ " is not defined, you have to specify the objective type as classification or regression" +
167
+ " with a customized objective function" )
168
+ }
169
+
170
+ var trainTestRatio = 1.0
171
+ if (overridedParams.contains(" train_test_ratio" )) {
172
+ logger.warn(" train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
173
+ " pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
174
+ " 'eval_set_names'" )
175
+ trainTestRatio = overridedParams.get(" train_test_ratio" ).get.asInstanceOf [Double ]
176
+ }
177
+
152
178
val nWorkers = overridedParams(" num_workers" ).asInstanceOf [Int ]
153
179
val round = overridedParams(" num_round" ).asInstanceOf [Int ]
154
180
val useExternalMemory = overridedParams
155
181
.getOrElse(" use_external_memory" , false ).asInstanceOf [Boolean ]
156
- val obj = overridedParams.getOrElse(" custom_obj" , null ).asInstanceOf [ObjectiveTrait ]
157
- val eval = overridedParams.getOrElse(" custom_eval" , null ).asInstanceOf [EvalTrait ]
182
+
158
183
val missing = overridedParams.getOrElse(" missing" , Float .NaN ).asInstanceOf [Float ]
159
184
val allowNonZeroForMissing = overridedParams
160
185
.getOrElse(" allow_non_zero_for_missing" , false )
161
186
.asInstanceOf [Boolean ]
162
- validateSparkSslConf
163
- var treeMethod : Option [String ] = None
164
- if (overridedParams.contains(" tree_method" )) {
165
- require(overridedParams(" tree_method" ) == " hist" ||
166
- overridedParams(" tree_method" ) == " approx" ||
167
- overridedParams(" tree_method" ) == " auto" ||
168
- overridedParams(" tree_method" ) == " gpu_hist" , " xgboost4j-spark only supports tree_method" +
169
- " as 'hist', 'approx', 'gpu_hist', and 'auto'" )
170
- treeMethod = Some (overridedParams(" tree_method" ).asInstanceOf [String ])
171
- }
172
187
188
+ val treeMethod : Option [String ] = overridedParams.get(" tree_method" ).map(_.toString)
173
189
// back-compatible with "gpu_hist"
174
190
val device : Option [String ] = if (treeMethod.exists(_ == " gpu_hist" )) {
175
191
Some (" cuda" )
176
192
} else overridedParams.get(" device" ).map(_.toString)
177
193
178
- if (overridedParams.contains(" train_test_ratio" )) {
179
- logger.warn(" train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
180
- " pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
181
- " 'eval_set_names'" )
182
- }
183
- require(nWorkers > 0 , " you must specify more than 0 workers" )
184
- if (obj != null ) {
185
- require(overridedParams.get(" objective_type" ).isDefined, " parameter \" objective_type\" " +
186
- " is not defined, you have to specify the objective type as classification or regression" +
187
- " with a customized objective function" )
188
- }
194
+ require(! (treeMethod.exists(_ == " approx" ) && device.exists(_ == " cuda" )),
195
+ " The tree method \" approx\" is not yet supported for Spark GPU cluster" )
196
+
189
197
val trackerConf = overridedParams.get(" tracker_conf" ) match {
190
198
case None => TrackerConf ()
191
199
case Some (conf : TrackerConf ) => conf
192
200
case _ => throw new IllegalArgumentException (" parameter \" tracker_conf\" must be an " +
193
201
" instance of TrackerConf." )
194
202
}
195
- val checkpointParam =
196
- ExternalCheckpointParams .extractParams(overridedParams)
197
203
198
- val trainTestRatio = overridedParams.getOrElse( " train_test_ratio " , 1.0 )
199
- . asInstanceOf [ Double ]
204
+ val checkpointParam = ExternalCheckpointParams .extractParams(overridedParams )
205
+
200
206
val seed = overridedParams.getOrElse(" seed" , System .nanoTime()).asInstanceOf [Long ]
201
207
val inputParams = XGBoostExecutionInputParams (trainTestRatio, seed)
202
208
0 commit comments