# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
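"""
Semi-auto parallel test for Llama-style rotary position embedding (RoPE).

A single-card run of ``RotaryPositionEmbedding`` serves as the baseline.
The same layer is then trained with the input sharded along the batch
dimension (data parallel) and with the q/k projection weights sharded
(model parallel) on a 1-D process mesh of two ranks, and outputs and
parameters are compared against the baseline.

The ``dtype``, ``backend`` and ``seed`` environment variables are expected
to be set by the test launcher. A typical invocation (assumed here, not
part of this file) would start the script on two devices, e.g. via
``python -m paddle.distributed.launch``.
"""
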
import os
import random

import numpy as np

import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.distributed import Shard

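# The fused RoPE kernel is not available in every Paddle build; fall back to
# the pure-Python rotation implemented below when the import fails.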
try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

BATCH_COUNT = 10
BATCH_SIZE = 16
SEQ_LEN = 128
NUM_HEADS = 8
HEAD_DIM = 64
HIDDEN_SIZE = NUM_HEADS * HEAD_DIM


class RotaryAngle(nn.Layer):
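    """Precomputes the cos/sin rotation angles used by rotary position embedding.

    The inverse frequencies are ``1 / base**(2i / dim)`` for ``i`` in
    ``[0, dim / 2)``, and the cached cos/sin tables have shape
    ``[1, max_position_embeddings, 1, dim]``.
    """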
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # [dim / 2]
        self.inv_freq = 1.0 / (
            self.base
            ** (
                paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32")
                / self.dim
            )
        )
        self._set_cos_sin_cache(seq_len=max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        # [seq_len]
        t = paddle.arange(seq_len, dtype="float32")
        # [seq_len, dim/2]
        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        # [1, seq_len, 1, dim]
        self.cos_cached = emb.cos()[None, :, None, :]
        self.sin_cached = emb.sin()[None, :, None, :]

    def forward(self, x, seq_len=None):
        # x: [bs, seq_len, num_heads, head_dim]
        cos = self.cos_cached[:, :seq_len, :, :]
        sin = self.sin_cached[:, :seq_len, :, :]
        return (
            cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
            sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return paddle.concat([-x2, x1], axis=-1)  # shape is the same as x


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
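    """Apply rotary position embedding to the query and key tensors.

    The cos/sin angles are gathered at ``position_ids`` (or sliced by
    sequence length when ``position_ids`` is None) and each tensor ``x`` is
    rotated as ``x * cos + rotate_half(x) * sin``.
    """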
    if position_ids is None:
        # Note: Only for LlamaForCausalLMPipe model pretraining
        cos = cos[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
        sin = sin[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
    else:
        cos = cos.squeeze(axis=[0, 2])  # [seq_len, dim]
        sin = sin.squeeze(axis=[0, 2])  # [seq_len, dim]
        cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
        sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class RotaryPositionEmbedding(nn.Layer):
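    """Projects the input to query/key states and applies RoPE to both.

    When ``is_use_fused_rope`` is True the fused
    ``fused_rotary_position_embedding`` op is used; otherwise the reference
    ``apply_rotary_pos_emb`` path is taken.
    """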
    def __init__(self, seq_len, num_heads, head_dim, is_use_fused_rope=False):
        super().__init__()
        self.seq_len = seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.rotary_angle = RotaryAngle(
            dim=self.head_dim, max_position_embeddings=self.seq_len
        )
        self.is_use_fused_rope = is_use_fused_rope
        self.hidden_size = self.num_heads * self.head_dim
        self.q_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size,
            bias_attr=False,
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size,
            bias_attr=False,
        )

    def forward(self, input):
        target_query_shape = [0, 0, self.num_heads, self.head_dim]
        query_states = self.q_proj(input).reshape(shape=target_query_shape)
        key_states = self.k_proj(input).reshape(shape=target_query_shape)

        cos, sin = self.rotary_angle(query_states, seq_len=self.seq_len)
        position_ids = paddle.arange(self.seq_len, dtype="int64").expand(
            (BATCH_SIZE, self.seq_len)
        )
        if self.is_use_fused_rope:
            query_states, key_states, _ = fused_rotary_position_embedding(
                query_states,
                key_states,
                v=None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
        else:
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, position_ids
            )
        return query_states, key_states


class TestLlamaRopeSemiAutoParallel:
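    """Semi-auto parallel RoPE test.

    Data-parallel (sharded input) and model-parallel (sharded q/k weights)
    runs are compared against the single-card baseline built in ``__init__``.
    """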
    def __init__(self):
        self._dtype = os.getenv("dtype")
        self._backend = os.getenv("backend")
        self._seed = int(os.getenv("seed"))
        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
        self.is_use_fuse_rope = False
        paddle.set_device(self._backend)
        self.init_single_card_net_result()

    def mp_shard_fn(self, layer_name, layer, process_mesh):
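        """Shard q_proj/k_proj weights along their output dim (column parallel)."""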
        if layer_name in ("q_proj", "k_proj"):
            layer.weight = dist.shard_tensor(
                layer.weight, process_mesh, [Shard(1)]
            )

    def set_use_fuse_rope_flag(self, is_use_fuse_rope):
        self.is_use_fuse_rope = is_use_fuse_rope

    def set_random_seed(self, seed):
        random.seed(seed)
        np.random.seed(seed)
        paddle.seed(seed)

    def init_input_data(self):
        input = np.random.random([BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]).astype(
            self._dtype
        )
        input = paddle.to_tensor(input)
        return input

    def init_single_card_net_result(self):
        self.set_random_seed(self._seed)
        rotary_emb = RotaryPositionEmbedding(
            seq_len=SEQ_LEN,
            num_heads=NUM_HEADS,
            head_dim=HEAD_DIM,
            is_use_fused_rope=self.is_use_fuse_rope,
        )
        self.base_out, self.base_parameters = self.train_loop(rotary_emb)

    def train_loop(self, layer, shard_input=False):
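        """Train ``layer`` for BATCH_COUNT steps with SGD.

        When ``shard_input`` is True the input is sharded along its first
        (batch) dimension over the process mesh, i.e. data parallel.
        Returns the loss of the last step and the layer's parameters.
        """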
        # run forward and backward
        input_dist_attr = [Shard(0)]

        opt = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=layer.parameters()
        )
        for _ in range(BATCH_COUNT):
            input = self.init_input_data()
            if shard_input:
                input = dist.shard_tensor(input, self._mesh, input_dist_attr)
            query_states, key_states = layer(input)
            loss = paddle.sum(query_states + key_states)
            loss.backward()
            opt.step()
            opt.clear_grad()
        return loss, layer.parameters()

    def check_tensor_eq(self, a, b, rtol=1e-04, atol=1e-05, verbose=True):
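        """Assert two tensors are numerically close after casting to float32.

        ``None`` is treated as equal to ``None`` so that cleared gradients can
        be compared without special-casing at the call site.
        """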
        if a is None:
            assert b is None
            return
        np1 = a.astype("float32").numpy()
        np2 = b.astype("float32").numpy()
        np.testing.assert_allclose(
            np1, np2, rtol=rtol, atol=atol, verbose=verbose
        )

    def test_dp(self, is_use_fuse_rope=False):
        self.set_random_seed(self._seed)
        self.set_use_fuse_rope_flag(is_use_fuse_rope)

        dp_layer = RotaryPositionEmbedding(
            seq_len=SEQ_LEN,
            num_heads=NUM_HEADS,
            head_dim=HEAD_DIM,
            is_use_fused_rope=self.is_use_fuse_rope,
        )

        dp_out, dp_parameters = self.train_loop(
            dp_layer,
            shard_input=True,
        )
        self.check_tensor_eq(dp_out, self.base_out)
        for param, param_base in zip(dp_parameters, self.base_parameters):
            self.check_tensor_eq(param, param_base)
            self.check_tensor_eq(param.grad, param_base.grad)

    def test_mp(self, is_use_fuse_rope=False):
        self.set_random_seed(self._seed)
        self.set_use_fuse_rope_flag(is_use_fuse_rope)

        mp_layer = RotaryPositionEmbedding(
            seq_len=SEQ_LEN,
            num_heads=NUM_HEADS,
            head_dim=HEAD_DIM,
            is_use_fused_rope=self.is_use_fuse_rope,
        )
        mp_layer = dist.shard_layer(mp_layer, self._mesh, self.mp_shard_fn)
        mp_out, mp_parameters = self.train_loop(mp_layer)
        self.check_tensor_eq(mp_out, self.base_out)
        for param, param_base in zip(mp_parameters, self.base_parameters):
            self.check_tensor_eq(param, param_base)
            self.check_tensor_eq(param.grad, param_base.grad)

    def run_test_case(self):
        self.test_dp(is_use_fuse_rope=False)
        self.test_mp(is_use_fuse_rope=False)
        # The fused kernel is not available in every build (see the import
        # fallback above), so only exercise it when present.
        if fused_rotary_position_embedding is not None:
            self.test_dp(is_use_fuse_rope=True)
            self.test_mp(is_use_fuse_rope=True)


if __name__ == '__main__':
    TestLlamaRopeSemiAutoParallel().run_test_case()