Skip to content

Commit a88f07c

Browse files
authored
【FlexCheckpoint】upgrad full param (PaddlePaddle#76510)
* upgrad_full_param * fix * fix * fix comment * add test remove
1 parent d24c763 commit a88f07c

File tree

5 files changed

+978
-449
lines changed

5 files changed

+978
-449
lines changed

python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,11 @@ def _get_var_ref(var):
576576
# from source_state_shard_info and aha_statements. In this case, all destination_states
577577
# remain unsharded (not partitioned).
578578
for name, ref_t in self.input_vars.items():
579-
if name not in self.output_vars and ref_t.out_degree == 0:
579+
if (
580+
name not in self.output_vars
581+
and ref_t.out_degree == 0
582+
and name not in self.need_remove_input_vars
583+
):
580584
self.output_vars[name] = self.identity(ref_t)
581585
for name, ref_t in self.intermediate_vars.items():
582586
if name not in self.output_vars and ref_t.out_degree == 0:

0 commit comments

Comments
 (0)