
Commit aa9818f

trivialfis and wbo4958 authored
[backport][pyspark] Avoid repartition. (dmlc#10408) (dmlc#10411)
Co-authored-by: Bobby Wang <[email protected]>
1 parent 7e94cbf commit aa9818f

File tree

3 files changed: +28 -49 lines changed

doc/tutorials/spark_estimator.rst

Lines changed: 19 additions & 1 deletion

@@ -267,7 +267,7 @@ An example submit command is shown below with additional spark configurations an
   --conf spark.task.cpus=1 \
   --conf spark.executor.resource.gpu.amount=1 \
   --conf spark.task.resource.gpu.amount=0.08 \
-  --packages com.nvidia:rapids-4-spark_2.12:23.04.0 \
+  --packages com.nvidia:rapids-4-spark_2.12:24.04.1 \
   --conf spark.plugins=com.nvidia.spark.SQLPlugin \
   --conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \
   --archives xgboost_env.tar.gz#environment \
@@ -276,3 +276,21 @@ An example submit command is shown below with additional spark configurations an
 When rapids plugin is enabled, both of the JVM rapids plugin and the cuDF Python package
 are required. More configuration options can be found in the RAPIDS link above along with
 details on the plugin.
+
+Advanced Usage
+==============
+
+XGBoost needs to repartition the input dataset to num_workers partitions to ensure that
+num_workers training tasks run at the same time. However, repartitioning is a costly operation.
+
+In scenarios where the data is read from the source and fitted to XGBoost directly, without
+an intermediate shuffle stage, users can avoid repartitioning by setting the Spark
+configuration parameters ``spark.sql.files.maxPartitionNum`` and
+``spark.sql.files.minPartitionNum`` to num_workers. This tells Spark to automatically partition
+the dataset into the desired number of partitions.
+
+However, if the input dataset is skewed (i.e. the data is not evenly distributed), setting
+the partition number to num_workers may not be efficient. In this case, users can set
+``force_repartition=true`` to explicitly force XGBoost to repartition the dataset, even if
+the partition number already equals num_workers. This ensures the data is evenly
+distributed across the workers.
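
For illustration only (not part of this commit), a minimal PySpark sketch of the two options described in the new "Advanced Usage" section above. The dataset path, column names, and worker count are placeholders, and ``spark.sql.files.maxPartitionNum`` assumes Spark 3.5+:

    from pyspark.sql import SparkSession
    from xgboost.spark import SparkXGBClassifier

    num_workers = 4

    # Option 1: have Spark produce exactly num_workers partitions at read time,
    # so the partition count already matches num_workers and XGBoost skips its
    # own (costly) repartition.
    spark = (
        SparkSession.builder
        .config("spark.sql.files.minPartitionNum", num_workers)
        .config("spark.sql.files.maxPartitionNum", num_workers)  # Spark 3.5+
        .getOrCreate()
    )
    train_df = spark.read.parquet("/path/to/train.parquet")  # placeholder path

    clf = SparkXGBClassifier(
        num_workers=num_workers,
        features_col="features",  # placeholder column names
        label_col="label",
    )
    model = clf.fit(train_df)

    # Option 2: the input is skewed, so force an explicit repartition even though
    # the partition count already equals num_workers.
    clf_skewed = SparkXGBClassifier(
        num_workers=num_workers,
        features_col="features",
        label_col="label",
        force_repartition=True,
    )
    model_skewed = clf_skewed.fit(train_df)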

python-package/xgboost/spark/core.py

Lines changed: 8 additions & 47 deletions

@@ -691,50 +691,15 @@ def _convert_to_sklearn_model(self, booster: bytearray, config: str) -> XGBModel
         sklearn_model._Booster.load_config(config)
         return sklearn_model
 
-    def _query_plan_contains_valid_repartition(self, dataset: DataFrame) -> bool:
-        """
-        Returns true if the latest element in the logical plan is a valid repartition
-        The logic plan string format is like:
-
-        == Optimized Logical Plan ==
-        Repartition 4, true
-        +- LogicalRDD [features#12, label#13L], false
-
-        i.e., the top line in the logical plan is the last operation to execute.
-        so, in this method, we check the first line, if it is a "Repartition" operation,
-        and the result dataframe has the same partition number with num_workers param,
-        then it means the dataframe is well repartitioned and we don't need to
-        repartition the dataframe again.
-        """
-        num_partitions = dataset.rdd.getNumPartitions()
-        assert dataset._sc._jvm is not None
-        query_plan = dataset._sc._jvm.PythonSQLUtils.explainString(
-            dataset._jdf.queryExecution(), "extended"
-        )
-        start = query_plan.index("== Optimized Logical Plan ==")
-        start += len("== Optimized Logical Plan ==") + 1
-        num_workers = self.getOrDefault(self.num_workers)
-        if (
-            query_plan[start : start + len("Repartition")] == "Repartition"
-            and num_workers == num_partitions
-        ):
-            return True
-        return False
-
     def _repartition_needed(self, dataset: DataFrame) -> bool:
         """
         We repartition the dataset if the number of workers is not equal to the number of
-        partitions. There is also a check to make sure there was "active partitioning"
-        where either Round Robin or Hash partitioning was actively used before this stage.
-        """
+        partitions."""
         if self.getOrDefault(self.force_repartition):
             return True
-        try:
-            if self._query_plan_contains_valid_repartition(dataset):
-                return False
-        except Exception:  # pylint: disable=broad-except
-            pass
-        return True
+        num_workers = self.getOrDefault(self.num_workers)
+        num_partitions = dataset.rdd.getNumPartitions()
+        return not num_workers == num_partitions
 
     def _get_distributed_train_params(self, dataset: DataFrame) -> Dict[str, Any]:
         """
@@ -871,14 +836,10 @@ def _prepare_input(self, dataset: DataFrame) -> Tuple[DataFrame, FeatureProp]:
                 num_workers,
             )
 
-        if self._repartition_needed(dataset) or (
-            self.isDefined(self.validationIndicatorCol)
-            and self.getOrDefault(self.validationIndicatorCol) != ""
-        ):
-            # If validationIndicatorCol defined, we always repartition dataset
-            # to balance data, because user might unionise train and validation dataset,
-            # without shuffling data then some partitions might contain only train or validation
-            # dataset.
+        if self._repartition_needed(dataset):
+            # If validationIndicatorCol is defined and the user unions train and validation
+            # datasets, force_repartition must be set to true to force a repartition.
+            # Otherwise some partitions might contain only train or only validation data.
             if self.getOrDefault(self.repartition_random_shuffle):
                 # In some cases, spark round-robin repartition might cause data skew
                 # use random shuffle can address it.
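
With the simplified check above, repartitioning now happens only when the partition count differs from num_workers or when ``force_repartition`` is set. As the updated comment notes, a user who unions train and validation DataFrames should request the balancing shuffle explicitly. A small illustrative sketch (not part of the commit; the toy data and the "isVal" column name are placeholders):

    from pyspark.ml.linalg import Vectors
    from pyspark.sql import SparkSession, functions as F
    from xgboost.spark import SparkXGBClassifier

    spark = SparkSession.builder.getOrCreate()

    # Tiny train/validation DataFrames using the default "features"/"label" columns.
    train_df = spark.createDataFrame(
        [(Vectors.dense(1.0, 2.0), 0), (Vectors.dense(3.0, 4.0), 1)],
        ["features", "label"],
    ).withColumn("isVal", F.lit(False))
    valid_df = spark.createDataFrame(
        [(Vectors.dense(5.0, 6.0), 1)],
        ["features", "label"],
    ).withColumn("isVal", F.lit(True))
    combined = train_df.unionByName(valid_df)

    # A union whose partition count already equals num_workers is no longer
    # reshuffled automatically, so some partitions could hold only train rows or
    # only validation rows; force_repartition=True forces the balancing shuffle.
    clf = SparkXGBClassifier(
        num_workers=2,
        validation_indicator_col="isVal",
        force_repartition=True,
    )
    model = clf.fit(combined)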

tests/test_distributed/test_with_spark/test_spark_local_cluster.py

Lines changed: 1 addition & 1 deletion

@@ -474,7 +474,7 @@ def test_repartition(self):
 
         classifier = SparkXGBClassifier(num_workers=self.n_workers)
         basic = self.cls_df_train_distributed
-        self.assertTrue(classifier._repartition_needed(basic))
+        self.assertTrue(not classifier._repartition_needed(basic))
         bad_repartitioned = basic.repartition(self.n_workers + 1)
         self.assertTrue(classifier._repartition_needed(bad_repartitioned))
         good_repartitioned = basic.repartition(self.n_workers)
