
Commit fda20a7

liuzhenhai93 authored
sharding reshard: backward compatibility with older versions (#7734)
* polish
* polish

Co-authored-by: liuzhenhai93 <[email protected]>
1 parent 6e9d465 commit fda20a7

File tree

1 file changed: +6 -5 lines changed


paddlenlp/trainer/utils/sharding_io.py

Lines changed: 6 additions & 5 deletions
@@ -226,10 +226,11 @@ def _need_reshard(self, checkpoint):
                 if optimizer._param2rank[k] != int(v):
                     return True
         else:
-            # reshard anyway
-            if "pp_overlap" not in sharding_meta:
-                return True
-            pp_overlap = sharding_meta["pp_overlap"]
+            pp_overlap = None
+            # backward compatibility
+            if "enable_overlap" in sharding_meta:
+                pp_overlap = sharding_meta["enable_overlap"]
+
             cur_pp_overlap = unwrap_optimizer(self.optimizer, DygraphShardingOptimizerV2).pp_overlap
             return pp_overlap != cur_pp_overlap
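The net effect: the reader no longer requires the "pp_overlap" key. It looks up "enable_overlap" and falls back to None when the key is absent, so checkpoints written before the key existed still compare cleanly against the current optimizer's boolean flag and trigger a reshard. A minimal standalone sketch of that check (the function and variable names here are illustrative, not the sharding_io.py API):

def needs_reshard(sharding_meta, cur_pp_overlap):
    # Hypothetical stand-in for the hunk above: older checkpoints may not
    # carry "enable_overlap" at all, so default to None instead of failing.
    pp_overlap = sharding_meta.get("enable_overlap")
    # None never equals a boolean flag, so legacy checkpoints always reshard.
    return pp_overlap != cur_pp_overlap

assert needs_reshard({}, cur_pp_overlap=False) is True   # legacy checkpoint
assert needs_reshard({"enable_overlap": True}, cur_pp_overlap=True) is False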

@@ -543,7 +544,7 @@ def _gather_sharding_metas(self):
         sharding_meta["param2rank"] = param2rank
         sharding_meta["structure_name_mapping"] = structure_name_mapping
         sharding_meta["sharding_strategy"] = sharding_strategy
-        sharding_meta["pp_overlap"] = pp_overlap
+        sharding_meta["enable_overlap"] = pp_overlap
         suffix = f"tp{self.args.tensor_parallel_rank:0>2d}_pp{self.args.pipeline_parallel_rank:0>2d}"
         sharding_metas[suffix] = sharding_meta
         sharding_metas_list = self._all_gather_simple_object(sharding_metas, self.hcg.get_model_parallel_group())
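Two details of this hunk, sketched standalone (plain variables stand in for the self.args fields, which are assumptions of this sketch): the flag is now persisted under "enable_overlap", matching the key the reader probes, and each rank pair gets a zero-padded suffix key.

# The overlap flag is written under the renamed key "enable_overlap".
pp_overlap = True
sharding_meta = {"enable_overlap": pp_overlap}

# The suffix zero-pads both ranks to two digits, e.g. ranks (0, 1) -> "tp00_pp01".
tensor_parallel_rank, pipeline_parallel_rank = 0, 1
suffix = f"tp{tensor_parallel_rank:0>2d}_pp{pipeline_parallel_rank:0>2d}"
assert suffix == "tp00_pp01"
sharding_metas = {suffix: sharding_meta}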
