@@ -691,50 +691,15 @@ def _convert_to_sklearn_model(self, booster: bytearray, config: str) -> XGBModel
        sklearn_model._Booster.load_config(config)
        return sklearn_model

-    def _query_plan_contains_valid_repartition(self, dataset: DataFrame) -> bool:
-        """
-        Returns true if the latest element in the logical plan is a valid repartition.
-        The logical plan string format looks like:
-
-        == Optimized Logical Plan ==
-        Repartition 4, true
-        +- LogicalRDD [features#12, label#13L], false
-
-        i.e., the top line in the logical plan is the last operation to execute.
-        So, in this method, we check whether the first line is a "Repartition"
-        operation and whether the resulting dataframe has the same number of
-        partitions as the num_workers param; if so, the dataframe is already well
-        repartitioned and we don't need to repartition it again.
-        """
-        num_partitions = dataset.rdd.getNumPartitions()
-        assert dataset._sc._jvm is not None
-        query_plan = dataset._sc._jvm.PythonSQLUtils.explainString(
-            dataset._jdf.queryExecution(), "extended"
-        )
-        start = query_plan.index("== Optimized Logical Plan ==")
-        start += len("== Optimized Logical Plan ==") + 1
-        num_workers = self.getOrDefault(self.num_workers)
-        if (
-            query_plan[start : start + len("Repartition")] == "Repartition"
-            and num_workers == num_partitions
-        ):
-            return True
-        return False
-
    def _repartition_needed(self, dataset: DataFrame) -> bool:
        """
        We repartition the dataset if the number of workers is not equal to the number of
-        partitions. There is also a check to make sure there was "active partitioning"
-        where either Round Robin or Hash partitioning was actively used before this stage.
-        """
+        partitions."""
        if self.getOrDefault(self.force_repartition):
            return True
-        try:
-            if self._query_plan_contains_valid_repartition(dataset):
-                return False
-        except Exception:  # pylint: disable=broad-except
-            pass
-        return True
+        num_workers = self.getOrDefault(self.num_workers)
+        num_partitions = dataset.rdd.getNumPartitions()
+        return num_workers != num_partitions

    def _get_distributed_train_params(self, dataset: DataFrame) -> Dict[str, Any]:
        """
@@ -871,14 +836,10 @@ def _prepare_input(self, dataset: DataFrame) -> Tuple[DataFrame, FeatureProp]:
            num_workers,
        )

-        if self._repartition_needed(dataset) or (
-            self.isDefined(self.validationIndicatorCol)
-            and self.getOrDefault(self.validationIndicatorCol) != ""
-        ):
-            # If validationIndicatorCol is defined, we always repartition the dataset
-            # to balance the data, because the user might have unioned the train and
-            # validation datasets without shuffling, so some partitions might contain
-            # only train or only validation data.
+        if self._repartition_needed(dataset):
+            # If validationIndicatorCol is defined and the user has unioned the train and
+            # validation datasets, force_repartition must be set to true; otherwise some
+            # partitions might contain only train or only validation data.
            if self.getOrDefault(self.repartition_random_shuffle):
                # In some cases, spark round-robin repartition might cause data skew;
                # using a random shuffle can address it.
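
A hypothetical user-side sketch of the consequence of this change (SparkXGBClassifier,
validation_indicator_col, force_repartition, and num_workers are the public
xgboost.spark API; the toy data is made up): a unioned train/validation dataframe is
no longer repartitioned automatically, so the caller must opt in with force_repartition:

from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit
from xgboost.spark import SparkXGBClassifier

spark = SparkSession.builder.master("local[2]").getOrCreate()
train_df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0), 0), (Vectors.dense(3.0, 4.0), 1)],
    ["features", "label"],
).withColumn("isVal", lit(False))
val_df = spark.createDataFrame(
    [(Vectors.dense(5.0, 6.0), 1), (Vectors.dense(7.0, 8.0), 0)],
    ["features", "label"],
).withColumn("isVal", lit(True))

# Without a shuffle, partitions of the union may hold only train or only
# validation rows, so force a repartition rather than relying on the removed
# query-plan heuristic.
clf = SparkXGBClassifier(
    validation_indicator_col="isVal",
    force_repartition=True,
    num_workers=2,
)
model = clf.fit(train_df.union(val_df))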