File tree Expand file tree Collapse file tree 1 file changed +6
-5
lines changed Expand file tree Collapse file tree 1 file changed +6
-5
lines changed Original file line number Diff line number Diff line change @@ -226,10 +226,11 @@ def _need_reshard(self, checkpoint):
226
226
if optimizer ._param2rank [k ] != int (v ):
227
227
return True
228
228
else :
229
- # reshard anyway
230
- if "pp_overlap" not in sharding_meta :
231
- return True
232
- pp_overlap = sharding_meta ["pp_overlap" ]
229
+ pp_overlap = None
230
+ # backward compatibility
231
+ if "enable_overlap" in sharding_meta :
232
+ pp_overlap = sharding_meta ["enable_overlap" ]
233
+
233
234
cur_pp_overlap = unwrap_optimizer (self .optimizer , DygraphShardingOptimizerV2 ).pp_overlap
234
235
return pp_overlap != cur_pp_overlap
235
236
@@ -543,7 +544,7 @@ def _gather_sharding_metas(self):
543
544
sharding_meta ["param2rank" ] = param2rank
544
545
sharding_meta ["structure_name_mapping" ] = structure_name_mapping
545
546
sharding_meta ["sharding_strategy" ] = sharding_strategy
546
- sharding_meta ["pp_overlap " ] = pp_overlap
547
+ sharding_meta ["enable_overlap " ] = pp_overlap
547
548
suffix = f"tp{ self .args .tensor_parallel_rank :0>2d} _pp{ self .args .pipeline_parallel_rank :0>2d} "
548
549
sharding_metas [suffix ] = sharding_meta
549
550
sharding_metas_list = self ._all_gather_simple_object (sharding_metas , self .hcg .get_model_parallel_group ())
You can’t perform that action at this time.
0 commit comments