
Commit 5973d60

[bp][spark] Make xgboost spark support large model size (dmlc#10984) (dmlc#11005)
---------

Signed-off-by: Weichen Xu <[email protected]>
Co-authored-by: WeichenXu <[email protected]>
1 parent f199039 commit 5973d60
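
What the patch does: previously the whole serialized booster traveled through Spark as a single string (one DataFrame cell during training, one text-file record on save), which can hit Spark's size limits for a single row or record once the model gets large. The fix splits the booster JSON into `_MODEL_CHUNK_SIZE` (4 MiB) pieces and rejoins them with `"".join(...)` on the receiving side. Below is a minimal, Spark-free sketch of that round trip; `chunk_string` is an illustrative helper, not part of the patch, which inlines the same loop (see the diffs that follow):

```python
# Minimal, Spark-free sketch of the chunking round trip this commit adds.
# `chunk_string` is an illustrative helper; the patch inlines the same loop.
_MODEL_CHUNK_SIZE = 4096 * 1024  # 4 MiB, the constant added in core.py


def chunk_string(s: str, chunk_size: int = _MODEL_CHUNK_SIZE) -> list:
    """Split a string into consecutive pieces of at most `chunk_size` chars."""
    return [s[offset : offset + chunk_size] for offset in range(0, len(s), chunk_size)]


booster_json = '{"learner": "..."}' * 10_000  # stand-in for a large serialized model
chunks = chunk_string(booster_json, chunk_size=64)
assert "".join(chunks) == booster_json  # joining the chunks recovers the original
```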

File tree (2 files changed: +42 −12 lines)

  python-package/xgboost/spark/core.py
  tests/test_distributed/test_with_spark/test_spark_local.py

python-package/xgboost/spark/core.py

Lines changed: 22 additions & 12 deletions
@@ -597,6 +597,9 @@ def _get_unwrapped_vec_cols(feature_col: Column) -> List[Column]:
     )
 
 
+_MODEL_CHUNK_SIZE = 4096 * 1024
+
+
 class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
     _input_kwargs: Dict[str, Any]
 
@@ -1091,25 +1094,27 @@ def _train_booster(
             context.barrier()
 
             if context.partitionId() == 0:
-                yield pd.DataFrame(
-                    data={
-                        "config": [booster.save_config()],
-                        "booster": [booster.save_raw("json").decode("utf-8")],
-                    }
-                )
+                config = booster.save_config()
+                yield pd.DataFrame({"data": [config]})
+                booster_json = booster.save_raw("json").decode("utf-8")
+
+                for offset in range(0, len(booster_json), _MODEL_CHUNK_SIZE):
+                    booster_chunk = booster_json[offset : offset + _MODEL_CHUNK_SIZE]
+                    yield pd.DataFrame({"data": [booster_chunk]})
 
         def _run_job() -> Tuple[str, str]:
             rdd = (
                 dataset.mapInPandas(
                     _train_booster,  # type: ignore
-                    schema="config string, booster string",
+                    schema="data string",
                 )
                 .rdd.barrier()
                 .mapPartitions(lambda x: x)
             )
             rdd_with_resource = self._try_stage_level_scheduling(rdd)
-            ret = rdd_with_resource.collect()[0]
-            return ret[0], ret[1]
+            ret = rdd_with_resource.collect()
+            data = [v[0] for v in ret]
+            return data[0], "".join(data[1:])
 
         get_logger(_LOG_TAG).info(
             "Running xgboost-%s on %s workers with"
@@ -1690,7 +1695,12 @@ def saveImpl(self, path: str) -> None:
         _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
         model_save_path = os.path.join(path, "model")
         booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
-        _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
+        booster_chunks = []
+
+        for offset in range(0, len(booster), _MODEL_CHUNK_SIZE):
+            booster_chunks.append(booster[offset : offset + _MODEL_CHUNK_SIZE])
+
+        _get_spark_session().sparkContext.parallelize(booster_chunks, 1).saveAsTextFile(
             model_save_path
         )
 
@@ -1721,8 +1731,8 @@ def load(self, path: str) -> "_SparkXGBModel":
         )
         model_load_path = os.path.join(path, "model")
 
-        ser_xgb_model = (
-            _get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
+        ser_xgb_model = "".join(
+            _get_spark_session().sparkContext.textFile(model_load_path).collect()
         )
 
         def create_xgb_model() -> "XGBModel":
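
One note on the new row protocol in `_train_booster`/`_run_job` above: partition 0 first yields a single row holding `booster.save_config()`, then one row per booster chunk, so after `collect()` the driver takes row 0 as the config and joins the remaining rows back into the model JSON. A hedged sketch of that reassembly, with a plain list standing in for the collected Spark rows:

```python
# Sketch of the reassembly in _run_job; the tuples stand in for the Row
# objects returned by collect(), each carrying the "data" string column.
collected = [
    ('{"version": [2, 1, 0]}',),  # row 0: booster.save_config() (stand-in value)
    ('{"learner": {"mo',),        # rows 1..n: consecutive booster JSON chunks
    ('del": "..."}}',),
]

data = [row[0] for row in collected]
config, booster_json = data[0], "".join(data[1:])
assert config == '{"version": [2, 1, 0]}'
assert booster_json == '{"learner": {"model": "..."}}'
```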

tests/test_distributed/test_with_spark/test_spark_local.py

Lines changed: 20 additions & 0 deletions
@@ -867,6 +867,26 @@ def test_regressor_model_pipeline_save_load(self, reg_data: RegData) -> None:
         )
         assert_model_compatible(model.stages[0], tmpdir)
 
+    def test_with_small_model_chunk_size(self, reg_data: RegData, monkeypatch) -> None:
+        import xgboost.spark.core
+
+        monkeypatch.setattr(xgboost.spark.core, "_MODEL_CHUNK_SIZE", 4)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = "file:" + tmpdir
+            regressor = SparkXGBRegressor(**reg_data.reg_params)
+            model = regressor.fit(reg_data.reg_df_train)
+            model.save(path)
+            loaded_model = SparkXGBRegressorModel.load(path)
+            assert model.uid == loaded_model.uid
+            for k, v in reg_data.reg_params.items():
+                assert loaded_model.getOrDefault(k) == v
+
+            pred_result = loaded_model.transform(reg_data.reg_df_test).collect()
+            for row in pred_result:
+                assert np.isclose(
+                    row.prediction, row.expected_prediction_with_params, atol=1e-3
+                )
+
     def test_device_param(self, reg_data: RegData, clf_data: ClfData) -> None:
         clf = SparkXGBClassifier(device="cuda", tree_method="exact")
         with pytest.raises(ValueError, match="not supported for distributed"):
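
The test forces the multi-chunk path by monkeypatching `_MODEL_CHUNK_SIZE` down to 4 characters: a small test model would otherwise fit inside a single 4 MiB chunk and leave the splitting and rejoining untested. A quick back-of-the-envelope check (sizes illustrative):

```python
# Why the test patches the chunk size: with the production 4 MiB chunks a
# tiny model serializes into one chunk, so the multi-chunk save/load path
# would never run; a 4-character chunk size yields many chunks instead.
def num_chunks(length: int, chunk_size: int) -> int:
    return -(-length // chunk_size)  # ceiling division

model_len = 1_000  # pretend the serialized test model is ~1 kB
assert num_chunks(model_len, 4096 * 1024) == 1  # default: single chunk
assert num_chunks(model_len, 4) == 250          # patched: 250 chunks
```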
