Commit 5a3c593

[AutoParallel] add llama rope sub model test (#59854)
1 parent e7716f1 commit 5a3c593
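The commit adds a standalone rotary-embedding (RoPE) sub-model test for PaddlePaddle's semi-auto-parallel API, plus a driver entry in the existing llama subnet unittest. The new script runs the same layer three ways on a 1-D process mesh: single card as the baseline, data parallel with the input batch sharded, and tensor parallel with the q_proj/k_proj weights sharded. The sketch below condenses those two placements from the calls used in the diff; the TinyProj wrapper and the tensor sizes are illustrative only, and it assumes two GPUs under a distributed launch.

import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.distributed import Shard

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Data parallel: shard the batch dimension (dim 0) of the input across the mesh.
dp_input = dist.shard_tensor(paddle.randn([16, 128, 512]), mesh, [Shard(0)])


class TinyProj(nn.Layer):  # illustrative stand-in for the q_proj/k_proj sub-layers
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(512, 512, bias_attr=False)

    def forward(self, x):
        return self.q_proj(x)


def shard_fn(layer_name, layer, process_mesh):
    # Tensor (model) parallel: shard the projection weight along its output dim.
    if layer_name == "q_proj":
        layer.weight = dist.shard_tensor(layer.weight, process_mesh, [Shard(1)])


mp_layer = dist.shard_layer(TinyProj(), mesh, shard_fn)
out = mp_layer(paddle.randn([16, 128, 512]))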

2 files changed: +268 −0 lines changed
test/auto_parallel/semi_auto_parallel_for_llama_rope.py

Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random

import numpy as np

import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.distributed import Shard

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

BATCH_COUNT = 10
BATCH_SIZE = 16
SEQ_LEN = 128
NUM_HEADS = 8
HEAD_DIM = 64
HIDDEN_SIZE = NUM_HEADS * HEAD_DIM


class RotaryAngle(nn.Layer):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # [dim / 2]
        self.inv_freq = 1.0 / (
            self.base
            ** (
                paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32")
                / self.dim
            )
        )
        self._set_cos_sin_cache(seq_len=max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        # [seq_len]
        t = paddle.arange(seq_len, dtype="float32")
        # [seq_len, dim/2]
        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        # [1, seqlen, 1, dim]
        self.cos_cached = emb.cos()[None, :, None, :]
        self.sin_cached = emb.sin()[None, :, None, :]

    def forward(self, x, seq_len=None):
        # x: [bs, seq_len, num_heads, head_dim]
        cos = self.cos_cached[:, :seq_len, :, :]
        sin = self.sin_cached[:, :seq_len, :, :]
        return (
            cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
            sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return paddle.concat([-x2, x1], axis=-1)  # shape is the same as x


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    if position_ids is None:
        # Note: Only for LlamaForCausalLMPipe model pretraining
        cos = cos[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
        sin = sin[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
    else:
        cos = cos.squeeze(axis=[0, 2])  # [seq_len, dim]
        sin = sin.squeeze(axis=[0, 2])  # [seq_len, dim]
        cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
        sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class RotaryPositionEmbedding(nn.Layer):
    def __init__(self, seq_len, num_heads, head_dim, is_use_fused_rope=False):
        super().__init__()
        self.seq_len = seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.rotary_angle = RotaryAngle(
            dim=self.head_dim, max_position_embeddings=self.seq_len
        )
        self.is_use_fused_rope = is_use_fused_rope
        self.hidden_size = self.num_heads * self.head_dim
        self.q_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size,
            bias_attr=False,
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size,
            bias_attr=False,
        )

    def forward(self, input):
        target_query_shape = [0, 0, self.num_heads, self.head_dim]
        query_states = self.q_proj(input).reshape(shape=target_query_shape)
        key_states = self.k_proj(input).reshape(shape=target_query_shape)

        cos, sin = self.rotary_angle(query_states, seq_len=self.seq_len)
        position_ids = paddle.arange(self.seq_len, dtype="int64").expand(
            (BATCH_SIZE, self.seq_len)
        )
        if self.is_use_fused_rope:
            query_states, key_states, _ = fused_rotary_position_embedding(
                query_states,
                key_states,
                v=None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
        else:
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, position_ids
            )
        return query_states, key_states


class TestLlamaRopeSemiAutoParallel:
    def __init__(self):
        self._dtype = os.getenv("dtype")
        self._backend = os.getenv("backend")
        self._seed = int(os.getenv("seed"))
        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
        self.is_use_fuse_rope = False
        paddle.set_device(self._backend)
        self.init_single_card_net_result()

    def mp_shard_fn(self, layer_name, layer, process_mesh):
        if layer_name == "q_proj" or layer_name == "k_proj":
            layer.weight = dist.shard_tensor(
                layer.weight, process_mesh, [Shard(1)]
            )

    def set_use_fuse_rope_flag(self, is_use_fuse_rope):
        self.is_use_fuse_rope = is_use_fuse_rope

    def set_random_seed(self, seed):
        random.seed(seed)
        np.random.seed(seed)
        paddle.seed(seed)

    def init_input_data(self):
        input = np.random.random([BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]).astype(
            self._dtype
        )
        input = paddle.to_tensor(input)
        return input

    def init_single_card_net_result(self):
        self.set_random_seed(self._seed)
        rotary_emb = RotaryPositionEmbedding(
            seq_len=SEQ_LEN,
            num_heads=NUM_HEADS,
            head_dim=HEAD_DIM,
            is_use_fused_rope=self.is_use_fuse_rope,
        )
        self.base_out, self.base_parameters = self.train_loop(rotary_emb)

    def train_loop(self, layer, shard_input=False):
        # run forward and backward
        input_dist_attr = [Shard(0)]

        opt = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=layer.parameters()
        )
        for _ in range(BATCH_COUNT):
            input = self.init_input_data()
            if shard_input:
                input = dist.shard_tensor(input, self._mesh, input_dist_attr)
            query_states, key_states = layer(input)
            loss = paddle.sum(query_states + key_states)
            loss.backward()
            opt.step()
            opt.clear_grad()
        return loss, layer.parameters()

    def check_tensor_eq(self, a, b, rtol=1e-04, atol=1e-05, verbose=True):
        if a is None:
            assert b is None
            return
        np1 = a.astype("float32").numpy()
        np2 = b.astype("float32").numpy()
        np.testing.assert_allclose(
            np1, np2, rtol=rtol, atol=atol, verbose=verbose
        )

    def test_dp(self, is_use_fuse_rope=False):
        self.set_random_seed(self._seed)

        dp_layer = RotaryPositionEmbedding(
            seq_len=SEQ_LEN,
            num_heads=NUM_HEADS,
            head_dim=HEAD_DIM,
            is_use_fused_rope=is_use_fuse_rope,
        )

        dp_out, dp_parameters = self.train_loop(
            dp_layer,
            shard_input=True,
        )
        self.check_tensor_eq(dp_out, self.base_out)
        for param, param_base in zip(dp_parameters, self.base_parameters):
            self.check_tensor_eq(param, param_base)
            self.check_tensor_eq(param.grad, param_base.grad)

    def test_mp(self, is_use_fuse_rope=False):
        self.set_random_seed(self._seed)

        mp_layer = RotaryPositionEmbedding(
            seq_len=SEQ_LEN,
            num_heads=NUM_HEADS,
            head_dim=HEAD_DIM,
            is_use_fused_rope=is_use_fuse_rope,
        )
        mp_layer = dist.shard_layer(mp_layer, self._mesh, self.mp_shard_fn)
        mp_out, mp_parameters = self.train_loop(mp_layer)
        self.check_tensor_eq(mp_out, self.base_out)
        for param, param_base in zip(mp_parameters, self.base_parameters):
            self.check_tensor_eq(param, param_base)
            self.check_tensor_eq(param.grad, param_base.grad)

    def run_test_case(self):
        self.test_dp(is_use_fuse_rope=False)
        self.test_mp(is_use_fuse_rope=False)
        self.test_dp(is_use_fuse_rope=True)
        self.test_mp(is_use_fuse_rope=True)


if __name__ == '__main__':
    TestLlamaRopeSemiAutoParallel().run_test_case()
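A side note on the math being exercised: rotate_half and apply_rotary_pos_emb above implement the half-split form of rotary position embedding, where each pair (x_k, x_{k + dim/2}) is rotated by the angle pos * base^(-2k/dim). A small NumPy check of that equivalence (illustration only, not part of the committed file):

import numpy as np

dim, pos, base = 8, 5, 10000.0
theta = base ** (-np.arange(0, dim, 2) / dim)  # per-pair frequencies, [dim/2]
angle = pos * theta                            # rotation angles at position `pos`
x = np.random.rand(dim)

# "rotate half" form, mirroring rotate_half()/apply_rotary_pos_emb() above
cos = np.concatenate([np.cos(angle), np.cos(angle)])
sin = np.concatenate([np.sin(angle), np.sin(angle)])
rotated = np.concatenate([-x[dim // 2:], x[: dim // 2]])
y = x * cos + rotated * sin

# complex form: rotate each pair (x_k, x_{k + dim/2}) by angle_k
z = (x[: dim // 2] + 1j * x[dim // 2:]) * np.exp(1j * angle)
np.testing.assert_allclose(y, np.concatenate([z.real, z.imag]), rtol=1e-12, atol=1e-12)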

test/auto_parallel/test_semi_auto_parallel_for_llama_subnet.py

Lines changed: 10 additions & 0 deletions
@@ -50,6 +50,16 @@ def test_mlp_subnet(self):
                 user_defined_envs=envs,
             )
 
+    def test_rope_subnet(self):
+        envs_list = test_base.gen_product_envs_list(
+            {"dtype": "float32", "seed": "2023"}, {"backend": ["gpu"]}
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "semi_auto_parallel_for_llama_rope.py",
+                user_defined_envs=envs,
+            )
+
     def test_decoder_subnet(self):
         envs_list = test_base.gen_product_envs_list(
             {"dtype": "float32", "seed": "2023"}, {"backend": ["gpu"]}
