diff --git a/paddlenlp/trainer/utils/reshard/common.py b/paddlenlp/trainer/utils/reshard/common.py
index 4bdd65ac395e..c6b9ae844c21 100644
--- a/paddlenlp/trainer/utils/reshard/common.py
+++ b/paddlenlp/trainer/utils/reshard/common.py
@@ -586,6 +586,8 @@ def merge_opt_state(opt_state_map):
 def split_structure_name_mapping(structure_name_mapping, group_getter):
     res = OrderedDict()
     for k, v in structure_name_mapping.items():
+        if k not in group_getter.structure_name_mapping:
+            continue
         group = group_getter.get_group(k)
         if group.id not in res:
             res[group.id] = OrderedDict()
diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py
index 159be1afda25..a045887d63a1 100644
--- a/paddlenlp/trainer/utils/sharding_io.py
+++ b/paddlenlp/trainer/utils/sharding_io.py
@@ -899,6 +899,10 @@ def _gather_sharding_metas(self):
         sharding_meta["param_meta_keys"] = ["shape", "dtype", "is_distributed", "no_sync"]
         sharding_meta["sharding_strategy"] = sharding_strategy
         sharding_meta["enable_overlap"] = pp_overlap
+        dp_metas_list = self._all_gather_simple_object(sharding_meta, self.hcg.get_data_parallel_group())
+        for e in dp_metas_list:
+            for key in ["structure_name_mapping", "param_meta"]:
+                sharding_meta[key].update(e[key])
         suffix = self._sharding_meta_suffix()
         sharding_metas[suffix] = sharding_meta
         sharding_metas_list = self._all_gather_simple_object(sharding_metas, self.hcg.get_model_parallel_group())
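
Note (not part of the patch): a minimal standalone sketch of the merge pattern the second hunk adds, where each data-parallel rank contributes its own sharding_meta and every rank folds the gathered copies into its local mapping via dict.update. The gathered list below is faked with plain dicts; in the real code it would come from _all_gather_simple_object over the data-parallel group, and the concrete keys/values here are illustrative only.

from collections import OrderedDict

# Hypothetical per-rank metadata (stand-ins for what each dp rank would gather).
local_meta = {
    "structure_name_mapping": OrderedDict({"layer_0.weight": "p_0"}),
    "param_meta": OrderedDict({"layer_0.weight": {"shape": [8, 8], "dtype": "float32"}}),
}
dp_metas_list = [
    local_meta,
    {
        "structure_name_mapping": OrderedDict({"layer_1.weight": "p_1"}),
        "param_meta": OrderedDict({"layer_1.weight": {"shape": [8, 4], "dtype": "float32"}}),
    },
]

# Same merge loop as the patch: union the mappings contributed by every dp rank.
for e in dp_metas_list:
    for key in ["structure_name_mapping", "param_meta"]:
        local_meta[key].update(e[key])

print(sorted(local_meta["structure_name_mapping"]))  # ['layer_0.weight', 'layer_1.weight']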