@@ -597,6 +597,9 @@ def _get_unwrapped_vec_cols(feature_col: Column) -> List[Column]:
     )
 
 
+_MODEL_CHUNK_SIZE = 4096 * 1024
+
+
 class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
     _input_kwargs: Dict[str, Any]
 
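For scale, the new constant slices the serialized model into 4 MiB pieces of JSON text. A quick sanity check of the chunk count this implies (the 1 GiB model size is a made-up figure for illustration):

```python
_MODEL_CHUNK_SIZE = 4096 * 1024  # 4,194,304 characters, i.e. 4 MiB per chunk

# Hypothetical serialized model size, purely for illustration.
model_len = 1024 * 1024 * 1024  # 1 GiB of JSON text

# Number of chunk rows the loops below would emit (ceiling division).
num_chunks = -(-model_len // _MODEL_CHUNK_SIZE)
print(num_chunks)  # 256
```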
@@ -1091,25 +1094,27 @@ def _train_booster(
             context.barrier()
 
             if context.partitionId() == 0:
-                yield pd.DataFrame(
-                    data={
-                        "config": [booster.save_config()],
-                        "booster": [booster.save_raw("json").decode("utf-8")],
-                    }
-                )
+                config = booster.save_config()
+                yield pd.DataFrame({"data": [config]})
+                booster_json = booster.save_raw("json").decode("utf-8")
+
+                for offset in range(0, len(booster_json), _MODEL_CHUNK_SIZE):
+                    booster_chunk = booster_json[offset : offset + _MODEL_CHUNK_SIZE]
+                    yield pd.DataFrame({"data": [booster_chunk]})
 
         def _run_job() -> Tuple[str, str]:
             rdd = (
                 dataset.mapInPandas(
                     _train_booster,  # type: ignore
-                    schema="config string, booster string",
+                    schema="data string",
                 )
                 .rdd.barrier()
                 .mapPartitions(lambda x: x)
             )
             rdd_with_resource = self._try_stage_level_scheduling(rdd)
-            ret = rdd_with_resource.collect()[0]
-            return ret[0], ret[1]
+            ret = rdd_with_resource.collect()
+            data = [v[0] for v in ret]
+            return data[0], "".join(data[1:])
 
         get_logger(_LOG_TAG).info(
             "Running xgboost-%s on %s workers with"
@@ -1690,7 +1695,12 @@ def saveImpl(self, path: str) -> None:
         _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
         model_save_path = os.path.join(path, "model")
         booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
-        _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
+        booster_chunks = []
+
+        for offset in range(0, len(booster), _MODEL_CHUNK_SIZE):
+            booster_chunks.append(booster[offset : offset + _MODEL_CHUNK_SIZE])
+
+        _get_spark_session().sparkContext.parallelize(booster_chunks, 1).saveAsTextFile(
             model_save_path
         )
 
@@ -1721,8 +1731,8 @@ def load(self, path: str) -> "_SparkXGBModel":
         )
         model_load_path = os.path.join(path, "model")
 
-        ser_xgb_model = (
-            _get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
+        ser_xgb_model = "".join(
+            _get_spark_session().sparkContext.textFile(model_load_path).collect()
         )
 
         def create_xgb_model() -> "XGBModel":
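Writer and reader stay symmetric: `saveImpl` now writes one text line per 4 MiB chunk into a single partition instead of one giant line, and `load` re-joins all lines. Assuming the serialized JSON carries no embedded newlines (a line-based format would otherwise mangle chunk boundaries), the encoding is lossless; a file-free sketch of that invariant, with plain lists standing in for the RDDs:

```python
# Stand-in for xgb_model.get_booster().save_raw("json").decode("utf-8");
# assumed here to contain no newline characters.
booster = '{"learner": {"objective": "binary:logistic"}}'

_MODEL_CHUNK_SIZE = 10  # shrunk for illustration

# saveImpl side: the chunk list handed to parallelize(booster_chunks, 1).saveAsTextFile(...).
booster_chunks = []
for offset in range(0, len(booster), _MODEL_CHUNK_SIZE):
    booster_chunks.append(booster[offset : offset + _MODEL_CHUNK_SIZE])

# saveAsTextFile writes one line per element; with a single partition,
# textFile(...).collect() reads them back in order from one part file.
lines_on_disk = booster_chunks

# load side, mirroring ser_xgb_model = "".join(... .collect()).
ser_xgb_model = "".join(lines_on_disk)
assert ser_xgb_model == booster
```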