Skip to content

Commit 932d720

Browse files
wbo4958 and trivialfis authored
[jvm-packages] refine tracker (dmlc#10313)
Co-authored-by: Jiaming Yuan <[email protected]>
1 parent 966dc81 commit 932d720

File tree

8 files changed

+71
-92
lines changed

8 files changed

+71
-92
lines changed

jvm-packages/pom.xml

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -35,16 +35,17 @@
3535
<maven.compiler.target>1.8</maven.compiler.target>
3636
<flink.version>1.19.0</flink.version>
3737
<junit.version>4.13.2</junit.version>
38-
<spark.version>3.4.1</spark.version>
39-
<spark.version.gpu>3.4.1</spark.version.gpu>
38+
<spark.version>3.5.1</spark.version>
39+
<spark.version.gpu>3.5.1</spark.version.gpu>
40+
<fasterxml.jackson.version>2.15.2</fasterxml.jackson.version>
4041
<scala.version>2.12.18</scala.version>
4142
<scala.binary.version>2.12</scala.binary.version>
4243
<hadoop.version>3.4.0</hadoop.version>
4344
<maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
4445
<log.capi.invocation>OFF</log.capi.invocation>
4546
<use.cuda>OFF</use.cuda>
46-
<cudf.version>23.12.1</cudf.version>
47-
<spark.rapids.version>23.12.1</spark.rapids.version>
47+
<cudf.version>24.04.0</cudf.version>
48+
<spark.rapids.version>24.04.0</spark.rapids.version>
4849
<cudf.classifier>cuda12</cudf.classifier>
4950
<scalatest.version>3.2.18</scalatest.version>
5051
<scala-collection-compat.version>2.12.0</scala-collection-compat.version>
@@ -489,11 +490,6 @@
489490
<artifactId>kryo</artifactId>
490491
<version>5.6.0</version>
491492
</dependency>
492-
<dependency>
493-
<groupId>com.fasterxml.jackson.core</groupId>
494-
<artifactId>jackson-databind</artifactId>
495-
<version>2.14.2</version>
496-
</dependency>
497493
<dependency>
498494
<groupId>commons-logging</groupId>
499495
<artifactId>commons-logging</artifactId>

jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -176,7 +176,7 @@ public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dtrain,
176176
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
177177
if (tracker.start()) {
178178
return dtrain
179-
.mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs()))
179+
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerArgs()))
180180
.reduce((x, y) -> x)
181181
.collect()
182182
.get(0);

jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright (c) 2021-2022 by Contributors
2+
Copyright (c) 2021-2024 by Contributors
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ import org.apache.spark.{SparkContext, TaskContext}
2929
import org.apache.spark.ml.{Estimator, Model}
3030
import org.apache.spark.rdd.RDD
3131
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
32-
import org.apache.spark.sql.catalyst.encoders.RowEncoder
32+
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
3333
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
3434
import org.apache.spark.sql.functions.{col, collect_list, struct}
3535
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
@@ -444,7 +444,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
444444
.groupBy(groupName)
445445
.agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list")
446446

447-
implicit val encoder = RowEncoder(schema)
447+
implicit val encoder = ExpressionEncoder(RowEncoder.encoderFor(schema, false))
448448
// Expand the grouped rows after repartition
449449
repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
450450
new Iterator[Row] {

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

Lines changed: 27 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -233,24 +233,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
233233
xgbExecParam.setRawParamMap(overridedParams)
234234
xgbExecParam
235235
}
236-
237-
private[spark] def buildRabitParams : Map[String, String] = Map(
238-
"rabit_reduce_ring_mincount" ->
239-
overridedParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString,
240-
"rabit_debug" ->
241-
(overridedParams.getOrElse("verbosity", 0).toString.toInt == 3).toString,
242-
"rabit_timeout" ->
243-
(overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0).toString,
244-
"rabit_timeout_sec" -> {
245-
if (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0) {
246-
overridedParams.get("rabit_timeout").toString
247-
} else {
248-
"1800"
249-
}
250-
},
251-
"DMLC_WORKER_CONNECT_RETRY" ->
252-
overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString
253-
)
254236
}
255237

256238
/**
@@ -475,17 +457,15 @@ object XGBoost extends XGBoostStageLevel {
475457
}
476458
}
477459

478-
/** visiable for testing */
479-
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
480-
val tracker: ITracker = new RabitTracker(
481-
nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
482-
tracker
483-
}
484-
485-
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
486-
val tracker = getTracker(nWorkers, trackerConf)
460+
// Executes the provided code block inside a tracker and then stops the tracker
461+
private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = {
462+
val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout)
487463
require(tracker.start(), "FAULT: Failed to start tracker")
488-
tracker
464+
try {
465+
block(tracker)
466+
} finally {
467+
tracker.stop()
468+
}
489469
}
490470

491471
/**
@@ -501,55 +481,53 @@ object XGBoost extends XGBoostStageLevel {
501481
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
502482

503483
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
504-
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
505-
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
484+
val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams
506485

507-
val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
486+
val prevBooster = runtimeParams.checkpointParam.map { checkpointParam =>
508487
val checkpointManager = new ExternalCheckpointManager(
509488
checkpointParam.checkpointPath,
510489
FileSystem.get(sc.hadoopConfiguration))
511-
checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
490+
checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds)
512491
checkpointManager.loadCheckpointAsScalaBooster()
513492
}.orNull
514493

515494
// Get the training data RDD and the cachedRDD
516-
val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)
495+
val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams)
517496

518497
try {
519-
// Train for every ${savingRound} rounds and save the partially completed booster
520-
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
521-
val (booster, metrics) = try {
522-
tracker.workerArgs().putAll(xgbRabitParams)
523-
val rabitEnv = tracker.workerArgs
498+
val (booster, metrics) = withTracker(
499+
runtimeParams.numWorkers,
500+
runtimeParams.trackerConf
501+
) { tracker =>
502+
val rabitEnv = tracker.getWorkerArgs()
524503

525-
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
504+
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter =>
526505
var optionWatches: Option[() => Watches] = None
527506

528507
// take the first Watches to train
529508
if (iter.hasNext) {
530509
optionWatches = Some(iter.next())
531510
}
532511

533-
optionWatches.map { buildWatches => buildDistributedBooster(buildWatches,
534-
xgbExecParams, rabitEnv, xgbExecParams.obj, xgbExecParams.eval, prevBooster)}
535-
.getOrElse(throw new RuntimeException("No Watches to train"))
536-
537-
}}
512+
optionWatches.map { buildWatches =>
513+
buildDistributedBooster(buildWatches,
514+
runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster)
515+
}.getOrElse(throw new RuntimeException("No Watches to train"))
516+
}
538517

539-
val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, xgbExecParams,
518+
val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams,
540519
boostersAndMetrics)
541520
// The repartition step is to make training stage as ShuffleMapStage, so that when one
542521
// of the training task fails the training stage can retry. ResultStage won't retry when
543522
// it fails.
544523
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
545524
(booster, metrics)
546-
} finally {
547-
tracker.stop()
548525
}
526+
549527
// we should delete the checkpoint directory after a successful training
550-
xgbExecParams.checkpointParam.foreach {
528+
runtimeParams.checkpointParam.foreach {
551529
cpParam =>
552-
if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
530+
if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) {
553531
val checkpointManager = new ExternalCheckpointManager(
554532
cpParam.checkpointPath,
555533
FileSystem.get(sc.hadoopConfiguration))

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
4545

4646
val tracker = new RabitTracker(numWorkers)
4747
tracker.start()
48-
val trackerEnvs = tracker. workerArgs
48+
val trackerEnvs = tracker.getWorkerArgs
4949

5050
val workerCount: Int = numWorkers
5151
/*
@@ -84,7 +84,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
8484
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
8585
val tracker = new RabitTracker(numWorkers)
8686
tracker.start()
87-
val trackerEnvs = tracker.workerArgs
87+
val trackerEnvs = tracker.getWorkerArgs
8888

8989
val workerCount: Int = numWorkers
9090

jvm-packages/xgboost4j/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,12 @@
5353
<version>${scalatest.version}</version>
5454
<scope>provided</scope>
5555
</dependency>
56+
<dependency>
57+
<groupId>com.fasterxml.jackson.core</groupId>
58+
<artifactId>jackson-databind</artifactId>
59+
<version>${fasterxml.jackson.version}</version>
60+
<scope>provided</scope>
61+
</dependency>
5662
</dependencies>
5763

5864
<build>

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java

Lines changed: 2 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
*
88
* - start(timeout): Start the tracker awaiting for worker connections, with a given
99
* timeout value (in seconds).
10-
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
10+
* - getWorkerArgs(): Return the arguments needed to initialize Rabit clients.
1111
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
1212
* milliseconds.
1313
*
@@ -21,21 +21,8 @@
2121
* brokers connections between workers.
2222
*/
2323
public interface ITracker extends Thread.UncaughtExceptionHandler {
24-
enum TrackerStatus {
25-
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
2624

27-
private int statusCode;
28-
29-
TrackerStatus(int statusCode) {
30-
this.statusCode = statusCode;
31-
}
32-
33-
public int getStatusCode() {
34-
return this.statusCode;
35-
}
36-
}
37-
38-
Map<String, Object> workerArgs() throws XGBoostError;
25+
Map<String, Object> getWorkerArgs() throws XGBoostError;
3926

4027
boolean start() throws XGBoostError;
4128

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java

Lines changed: 25 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,19 @@
1+
/*
2+
Copyright (c) 2014-2024 by Contributors
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
117
package ml.dmlc.xgboost4j.java;
218

319
import java.util.Map;
@@ -10,14 +26,12 @@
1026

1127
/**
1228
* Java implementation of the Rabit tracker to coordinate distributed workers.
13-
*
14-
* The tracker must be started on driver node before running distributed jobs.
1529
*/
1630
public class RabitTracker implements ITracker {
1731
// Maybe per tracker logger?
1832
private static final Log logger = LogFactory.getLog(RabitTracker.class);
1933
private long handle = 0;
20-
private Thread tracker_daemon;
34+
private Thread trackerDaemon;
2135

2236
public RabitTracker(int numWorkers) throws XGBoostError {
2337
this(numWorkers, "");
@@ -44,24 +58,22 @@ public void uncaughtException(Thread t, Throwable e) {
4458
} catch (InterruptedException ex) {
4559
logger.error(ex);
4660
} finally {
47-
this.tracker_daemon.interrupt();
61+
this.trackerDaemon.interrupt();
4862
}
4963
}
5064

5165
/**
5266
* Get environments that can be used to pass to worker.
5367
* @return The environment settings.
5468
*/
55-
public Map<String, Object> workerArgs() throws XGBoostError {
69+
public Map<String, Object> getWorkerArgs() throws XGBoostError {
5670
// fixme: timeout
5771
String[] args = new String[1];
5872
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
5973
ObjectMapper mapper = new ObjectMapper();
60-
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
61-
};
6274
Map<String, Object> config;
6375
try {
64-
config = mapper.readValue(args[0], typeRef);
76+
config = mapper.readValue(args[0], new TypeReference<Map<String, Object>>() {});
6577
} catch (JsonProcessingException ex) {
6678
throw new XGBoostError("Failed to get worker arguments.", ex);
6779
}
@@ -74,18 +86,18 @@ public void stop() throws XGBoostError {
7486

7587
public boolean start() throws XGBoostError {
7688
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
77-
this.tracker_daemon = new Thread(() -> {
89+
this.trackerDaemon = new Thread(() -> {
7890
try {
79-
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
91+
waitFor(0);
8092
} catch (XGBoostError ex) {
8193
logger.error(ex);
8294
return; // exit the thread
8395
}
8496
});
85-
this.tracker_daemon.setDaemon(true);
86-
this.tracker_daemon.start();
97+
this.trackerDaemon.setDaemon(true);
98+
this.trackerDaemon.start();
8799

88-
return this.tracker_daemon.isAlive();
100+
return this.trackerDaemon.isAlive();
89101
}
90102

91103
public void waitFor(long timeout) throws XGBoostError {

0 commit comments

Comments
 (0)