From 1693da5be2f42206bed253cdf4766cf82372eeb9 Mon Sep 17 00:00:00 2001 From: zty-king <17786324919@163.com> Date: Sat, 27 Sep 2025 11:35:30 +0000 Subject: [PATCH] fix_optimizer_init --- .../distributed/flex_checkpoint/aoa/aoa_engine.py | 13 +++++++++---- .../flex_checkpoint/dcp/load_state_dict.py | 4 +++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py b/python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py index 2a7fa85d22cda5..5f472e7ed1e7f9 100644 --- a/python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py +++ b/python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py @@ -152,9 +152,12 @@ def __init__( self.output_vars = {} self.need_remove_input_vars = set() self.need_add_output_vars = set() - + self.meaningless_optimizer_key = set() self.shape_propagation() + def get_meaningless_optimizer_key(self): + return self.meaningless_optimizer_key + def make_input_tensor(self, key: str, shape: tuple[int]) -> TensorDesc: base_slice = tuple([slice(0, s) for s in shape]) return TensorDesc([(key, base_slice, base_slice, None)], shape) @@ -380,9 +383,11 @@ def _get_var_ref(var): if name not in self.output_vars: if name in self.need_add_output_vars: self.output_vars[name] = None - else: - assert name in self.input_vars - self.output_vars[name] = self.input_vars[name] + else: # Not from src and not in output_vars: optimizer-created state that cannot be mapped, discard + if name not in self.input_vars: + self.meaningless_optimizer_key.add(name) + else: + self.output_vars[name] = self.input_vars[name] def find_source_slices( self, key: str, local_slice: tuple[slice, ...] diff --git a/python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py b/python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py index 21ca0e8a10d7ac..94bbca0ca72006 100644 --- a/python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py +++ b/python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py @@ -703,8 +703,10 @@ def _handle_aoa( dst_to_src_desc_mapping = {} new_load_dict = {} src_desc_to_postprocess_list = {} - + meaningless_optimizer_key = aoa_engine.get_meaningless_optimizer_key() for param_name, tgt_shard in load_dict.items(): + if param_name in meaningless_optimizer_key: + continue tgt_desc = build_shard_desc(tgt_shard) shard_mappings = aoa_engine.find_shard_sources(tgt_desc) for mapping in shard_mappings: