
Commit 434f970

[AutoParallel] Add testcase for sequence parallel in dygraph mode. (#59841)
* [AutoParallel] Fix problems of sequence parallel in dynamic mode.
* Polish code.
* Remove TODO in transpose.cc
* Polish code.
* Remove useless modification.
* Polish code.
* Polish code.
* Remove useless modification.
* Allow partial status flow
* Add 3D auto_parallel test.
* Add 3D test and fix reshard bug.
* Add sequence parallel for llama.
* Polish code according to review comments.
* Fix bug of backward set in_grad dist_attr.
* [AutoParallel] Add testcase for sequence parallel in dygraph mode.
* Polish.
* [AutoParallel] Add testcase for sequence parallel in dygraph mode.
* Change place where sp call reshard

Co-authored-by: wuhuachaocoding <[email protected]>
1 parent c9ecd07 commit 434f970
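
For context, the core pattern the new test exercises: after the tensor-parallel matmuls, the activation carries a partial (pending all-reduce) status, and resharding it onto the sequence dimension lets the framework emit a reduce_scatter instead of a full all-reduce. A minimal standalone sketch of that pattern follows; the two-device mesh, the shapes, and the need to run under a distributed launcher are illustrative assumptions, not part of this commit:

import paddle
import paddle.distributed as dist
from paddle.distributed import Shard

# Illustrative 1-D tensor-parallel mesh over two devices (assumption).
mesh = dist.ProcessMesh([0, 1], dim_names=["mp"])

# Row-parallel weight, sharded along its input (contraction) dimension.
w = dist.shard_tensor(paddle.randn([1024, 1024]), mesh, [Shard(0)])
# [seq, batch, hidden] activation, sharded along hidden to match w.
x = dist.shard_tensor(paddle.randn([512, 8, 1024]), mesh, [Shard(2)])

# Both contraction dims are sharded, so the matmul output is partial
# (each rank holds an unreduced addend). Resharding it to Shard(0),
# the sequence dimension, can be realized as a reduce_scatter.
out = dist.reshard(paddle.matmul(x, w), mesh, [Shard(0)])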

File tree: 2 files changed, +269 -0 lines

test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_sp.py

Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
from auto_parallel.semi_auto_parallel_simple_net import (
    TestSimpleNetForSemiAutoParallel,
    create_numpy_like_random,
)

import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.distributed import Replicate, Shard

BATCH_SIZE = 8
SEQUENCE_LEN = 512
HIDDEN_SIZE = 1024
NUM_HEAD = 64
HEAD_DIM = 16
CLASS_NUM = 10

np.set_printoptions(threshold=np.inf)


class DemoNet(nn.Layer):
    def __init__(self, param_prefix="", is_sp=False, is_dp=False):
        super().__init__()

        # Two stages, each on its own mesh. placement0/placement1 give the
        # weights column-/row-parallel layouts; the sp_reshard placements
        # are the activation targets used to enter and leave the SP region.
        if is_dp:
            self.pp0_mesh = dist.ProcessMesh(
                [[0, 1], [2, 3]], dim_names=["dp", "mp"]
            )
            self.pp1_mesh = dist.ProcessMesh(
                [[4, 5], [6, 7]], dim_names=["dp", "mp"]
            )
            self.placement0 = [Replicate(), Shard(1)]
            self.placement1 = [Replicate(), Shard(0)]
            self.sp_reshard_placement0 = [Shard(1), Shard(0)]
            self.sp_reshard_placement1 = [Shard(1), Replicate()]
        else:
            self.pp0_mesh = dist.ProcessMesh([0, 1], dim_names=["mp"])
            self.pp1_mesh = dist.ProcessMesh([2, 3], dim_names=["mp"])
            self.placement0 = [Shard(1)]
            self.placement1 = [Shard(0)]
            self.sp_reshard_placement0 = [Shard(0)]
            self.sp_reshard_placement1 = [Replicate()]

        self.is_sp = is_sp
        self.is_dp = is_dp

        self.norm = nn.LayerNorm(HIDDEN_SIZE, epsilon=1e-4)
        self.linear_0_weight = dist.shard_tensor(
            self.create_parameter(
                shape=[HEAD_DIM, 4 * HIDDEN_SIZE],
                attr=create_numpy_like_random(param_prefix + "w_0"),
                dtype=paddle.float32,
                is_bias=False,
            ),
            self.pp0_mesh,
            self.placement0,
        )

        self.linear_1_weight = dist.shard_tensor(
            self.create_parameter(
                shape=[4 * HIDDEN_SIZE, HEAD_DIM],
                attr=create_numpy_like_random(param_prefix + "w_1"),
                dtype=paddle.float32,
                is_bias=False,
            ),
            self.pp0_mesh,
            self.placement1,
        )

        self.linear_2_weight = dist.shard_tensor(
            self.create_parameter(
                shape=[HIDDEN_SIZE, 4 * HIDDEN_SIZE],
                attr=create_numpy_like_random(param_prefix + "w_2"),
                dtype=paddle.float32,
                is_bias=False,
            ),
            self.pp1_mesh,
            self.placement0,
        )

        self.linear_3_weight = dist.shard_tensor(
            self.create_parameter(
                shape=[4 * HIDDEN_SIZE, CLASS_NUM],
                attr=create_numpy_like_random(param_prefix + "w_3"),
                dtype=paddle.float32,
                is_bias=False,
            ),
            self.pp1_mesh,
            self.placement1,
        )

    def forward(self, x):
        # Layer 0
        tgt = paddle.transpose(x, [1, 0, 2])
        out = paddle.reshape(x, [BATCH_SIZE, SEQUENCE_LEN, NUM_HEAD, HEAD_DIM])
        # [BATCH_SIZE, SEQUENCE_LEN, NUM_HEAD, HEAD_DIM] -> [BATCH_SIZE, NUM_HEAD, SEQUENCE_LEN, HEAD_DIM]
        out = paddle.transpose(out, [0, 2, 1, 3])
        out = paddle.matmul(out, self.linear_0_weight)
        out = paddle.matmul(out, self.linear_1_weight)
        out = paddle.transpose(out, [2, 0, 1, 3])
        out = paddle.reshape(out, [SEQUENCE_LEN, BATCH_SIZE, HIDDEN_SIZE])

        # SP region: this reshard should lower to a reduce_scatter.
        if self.is_sp:
            out = dist.reshard(out, self.pp0_mesh, self.sp_reshard_placement0)

        # out = out + tgt
        out = self.norm(out)

        out = dist.reshard(out, self.pp1_mesh, self.sp_reshard_placement1)

        out = paddle.matmul(out, self.linear_2_weight)
        out = paddle.matmul(out, self.linear_3_weight)
        out = paddle.transpose(out, [1, 0, 2])

        return out


class TestSimpleNetHybridStrategyForSemiAutoParallel(
    TestSimpleNetForSemiAutoParallel
):
    def __init__(self):
        self._dtype = os.getenv("dtype")
        self._backend = os.getenv("backend")
        self._seed = eval(os.getenv("seed"))
        self._is_dp = os.getenv("is_dp") == "true"
        if self._is_dp:
            self.pp0_mesh = dist.ProcessMesh(
                [[0, 1], [2, 3]], dim_names=["dp", "mp"]
            )

        paddle.set_device(self._backend)

        self.set_random_seed(self._seed)
        self.init_single_card_net_result()

    def init_single_card_net_result(self):
        self.set_random_seed(self._seed)
        self.base_loss, self.base_parameters = self.run_dynamic(
            DemoNet("demo_weight", is_sp=False, is_dp=self._is_dp), is_sp=False
        )

    def init_input_data(self):
        image = np.random.randn(BATCH_SIZE, SEQUENCE_LEN, HIDDEN_SIZE).astype(
            'float32'
        )
        label = np.random.randn(BATCH_SIZE, SEQUENCE_LEN, CLASS_NUM).astype(
            'float32'
        )

        return paddle.to_tensor(image), paddle.to_tensor(label)

    def check_tensor_eq(self, a, b, rtol=1e-7, atol=0, verbose=True):
        np1 = a.astype("float32").numpy()
        np2 = b.astype("float32").numpy()
        np.testing.assert_allclose(
            np1, np2, rtol=rtol, atol=atol, verbose=verbose
        )

    def run_dynamic(self, layer, is_sp=False):
        # create loss
        loss_fn = nn.MSELoss()
        # run forward and backward
        opt = paddle.optimizer.AdamW(
            learning_rate=0.1, parameters=layer.parameters()
        )
        for _ in range(5):
            image, label = self.init_input_data()
            if self._is_dp:
                image = dist.shard_tensor(
                    image, self.pp0_mesh, [Shard(0), Replicate()]
                )

            out = layer(image)

            loss = loss_fn(out, label)
            loss.backward()

            opt.step()
        return loss, layer.parameters()

    def test_dp_mp_sp_demo_net(self):
        self.set_random_seed(self._seed)
        model = DemoNet("dp_mp_hybrid_strategy", is_sp=True, is_dp=self._is_dp)

        (
            self.dp_mp_sp_loss,
            self.dp_mp_sp_parameters,
        ) = self.run_dynamic(model, is_sp=True)

        self.check_tensor_eq(self.dp_mp_sp_loss, self.base_loss)
        for param, param_base in zip(
            self.dp_mp_sp_parameters, self.base_parameters
        ):
            if param.grad is not None:
                self.check_tensor_eq(param, param_base)
                self.check_tensor_eq(param.grad, param_base.grad)

    def run_test_case(self):
        self.test_dp_mp_sp_demo_net()


if __name__ == '__main__':
    TestSimpleNetHybridStrategyForSemiAutoParallel().run_test_case()

test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py

Lines changed: 48 additions & 0 deletions
@@ -70,6 +70,54 @@ def test_simple_net_hybrid_strategy(self):
             ckpt_path.cleanup()
 
 
+class TestSemiAutoParallelHybridStrategyWithSP(
+    test_base.CommunicationTestDistBase
+):
+    def setUp(self):
+        super().setUp(
+            num_of_devices=4,
+            timeout=120,
+            nnode=1,
+        )
+        self._default_envs = {
+            "dtype": "float32",
+            "seed": "2023",
+        }
+        self._changeable_envs = {"backend": ["gpu"], "is_dp": ["false"]}
+
+    def test_simple_net_mp_pp_sp(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            ckpt_path = tempfile.TemporaryDirectory()
+            envs["ckpt_path"] = ckpt_path.name
+            self.run_test_case(
+                "semi_auto_parallel_simple_net_sp.py",
+                user_defined_envs=envs,
+            )
+            ckpt_path.cleanup()
+
+    def test_simple_net_dp_mp_pp_sp(self):
+        super().setUp(
+            num_of_devices=8,
+            timeout=120,
+            nnode=1,
+        )
+        self._changeable_envs = {"backend": ["gpu"], "is_dp": ["true"]}
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            ckpt_path = tempfile.TemporaryDirectory()
+            envs["ckpt_path"] = ckpt_path.name
+            self.run_test_case(
+                "semi_auto_parallel_simple_net_sp.py",
+                user_defined_envs=envs,
+            )
+            ckpt_path.cleanup()
+
+
 class TestSemiAutoParallelCrossMeshReshard(test_base.CommunicationTestDistBase):
     def setUp(self):
         super().setUp(
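
For reference, gen_product_envs_list is defined in test_base and is not shown in this diff; presumably it expands the fixed defaults plus the per-key lists of changeable values into one complete environment dict per combination. A minimal re-implementation sketch of that assumed behavior, for illustration only:

from itertools import product

def gen_product_envs_list(default_envs, changeable_envs):
    # Hypothetical equivalent of the test_base helper: take the Cartesian
    # product of the changeable values and merge each combination over the
    # fixed defaults.
    keys = list(changeable_envs)
    return [
        {**default_envs, **dict(zip(keys, values))}
        for values in product(*changeable_envs.values())
    ]

With the envs above, this would yield a single combination per test, e.g. {"dtype": "float32", "seed": "2023", "backend": "gpu", "is_dp": "false"}.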
