Skip to content

Commit 2dab658

Browse files
authored
fix aoa transpose corner case (PaddlePaddle#76234)
1 parent 6cab4e7 commit 2dab658

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -744,6 +744,7 @@ def _handle_aoa(
744744
and src_desc.global_shape == dst_desc.global_shape
745745
and src_desc.global_offset == dst_desc.global_offset
746746
and src_desc.dtype == dst_desc.dtype
747+
and mapping.postprocess_list is None
747748
):
748749
new_load_dict[idx] = ShardedWeight(
749750
key=src_desc.key,

test/flex_checkpoint/load_state_dict_transpose_logic.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -43,10 +43,10 @@ def sharded_state_dict(
4343

4444

4545
class SimpleMLP(Layer):
46-
def __init__(self, hidden_size=1024):
46+
def __init__(self, in_features=1024, out_features=1024):
4747
super().__init__()
4848
self.linear = ColumnParallelLinear(
49-
hidden_size, hidden_size * 2, has_bias=True
49+
in_features, out_features, has_bias=True
5050
)
5151

5252
def forward(self, x):
@@ -55,10 +55,10 @@ def forward(self, x):
5555

5656

5757
class SimpleMLPTransWeight(Layer):
58-
def __init__(self, hidden_size=1024):
58+
def __init__(self, in_features=1024, out_features=1024):
5959
super().__init__()
6060
self.linear = ColumnParallelLinearTransWeight(
61-
hidden_size, hidden_size * 2, has_bias=True
61+
in_features, out_features, has_bias=True
6262
)
6363

6464
def forward(self, x):
@@ -70,6 +70,8 @@ class TestLoadStateDictTransposeLogic:
7070
def __init__(self):
7171
self.aoa_config = {"aoa_statements": [os.getenv("aoa_statements")]}
7272
self.ckpt_path = tempfile.TemporaryDirectory().name
73+
self.in_features = 1024
74+
self.out_features = 2048
7375

7476
def run_test(self):
7577
self.run_save_state_dict()
@@ -99,5 +101,13 @@ def run_save_state_dict(self):
99101
dist.save_state_dict(sharded_state_dict, self.ckpt_path)
100102

101103

104+
class TestLoadStateDictTransposeLogic2(TestLoadStateDictTransposeLogic):
105+
def __init__(self):
106+
super().__init__()
107+
self.in_features = 1024
108+
self.out_features = 1024
109+
110+
102111
if __name__ == '__main__':
103112
TestLoadStateDictTransposeLogic().run_test()
113+
TestLoadStateDictTransposeLogic2().run_test()

0 commit comments

Comments
 (0)