# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
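
# Hybrid-strategy test for semi-auto parallel: a small two-stage (pp) network
# with tensor parallelism (mp), optional data parallelism (dp), and optional
# sequence parallelism (sp), checked against a non-sp baseline. Typically run
# under a distributed launcher on 4 ranks (8 when is_dp=true), with
# dtype/backend/seed/is_dp passed via environment variables.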

import os

import numpy as np
from auto_parallel.semi_auto_parallel_simple_net import (
    TestSimpleNetForSemiAutoParallel,
    create_numpy_like_random,
)

import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.distributed import Replicate, Shard

BATCH_SIZE = 8
SEQUENCE_LEN = 512
HIDDEN_SIZE = 1024
NUM_HEAD = 64
HEAD_DIM = 16
CLASS_NUM = 10
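# Note: the reshapes in DemoNet.forward rely on the invariant
# NUM_HEAD * HEAD_DIM == HIDDEN_SIZE (64 * 16 == 1024).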

np.set_printoptions(threshold=np.inf)


class DemoNet(nn.Layer):
    def __init__(self, param_prefix="", is_sp=False, is_dp=False):
        super().__init__()

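        # pp0_mesh and pp1_mesh model the two pipeline stages. With dp enabled
        # the test uses 8 ranks (2-way dp x 2-way mp per stage); otherwise it
        # uses 4 ranks (2-way mp per stage). The sp_reshard placements switch
        # activations into and out of a sequence-parallel layout, i.e. Shard(0)
        # on the sequence dim of the [SEQUENCE_LEN, BATCH_SIZE, HIDDEN_SIZE]
        # activation tensor.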
        if is_dp:
            self.pp0_mesh = dist.ProcessMesh(
                [[0, 1], [2, 3]], dim_names=["dp", "mp"]
            )
            self.pp1_mesh = dist.ProcessMesh(
                [[4, 5], [6, 7]], dim_names=["dp", "mp"]
            )
            self.placement0 = [Replicate(), Shard(1)]
            self.placement1 = [Replicate(), Shard(0)]
            self.sp_reshard_placement0 = [Shard(1), Shard(0)]
            self.sp_reshard_placement1 = [Shard(1), Replicate()]
        else:
            self.pp0_mesh = dist.ProcessMesh([0, 1], dim_names=["mp"])
            self.pp1_mesh = dist.ProcessMesh([2, 3], dim_names=["mp"])
            self.placement0 = [Shard(1)]
            self.placement1 = [Shard(0)]
            self.sp_reshard_placement0 = [Shard(0)]
            self.sp_reshard_placement1 = [Replicate()]

        self.is_sp = is_sp
        self.is_dp = is_dp

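        # The weight pairs follow the usual Megatron-style column/row-parallel
        # pairing: placement0 shards the output dim (Shard(1), column
        # parallel) and placement1 shards the input dim (Shard(0), row
        # parallel), so each back-to-back matmul pair needs only one
        # communication on its mp axis.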
        self.norm = nn.LayerNorm(HIDDEN_SIZE, epsilon=1e-4)
        self.linear_0_weight = dist.shard_tensor(
            self.create_parameter(
                shape=[HEAD_DIM, 4 * HIDDEN_SIZE],
                attr=create_numpy_like_random(param_prefix + "w_0"),
                dtype=paddle.float32,
                is_bias=False,
            ),
            self.pp0_mesh,
            self.placement0,
        )

        self.linear_1_weight = dist.shard_tensor(
            self.create_parameter(
                shape=[4 * HIDDEN_SIZE, HEAD_DIM],
                attr=create_numpy_like_random(param_prefix + "w_1"),
                dtype=paddle.float32,
                is_bias=False,
            ),
            self.pp0_mesh,
            self.placement1,
        )

        self.linear_2_weight = dist.shard_tensor(
            self.create_parameter(
                shape=[HIDDEN_SIZE, 4 * HIDDEN_SIZE],
                attr=create_numpy_like_random(param_prefix + "w_2"),
                dtype=paddle.float32,
                is_bias=False,
            ),
            self.pp1_mesh,
            self.placement0,
        )

        self.linear_3_weight = dist.shard_tensor(
            self.create_parameter(
                shape=[4 * HIDDEN_SIZE, CLASS_NUM],
                attr=create_numpy_like_random(param_prefix + "w_3"),
                dtype=paddle.float32,
                is_bias=False,
            ),
            self.pp1_mesh,
            self.placement1,
        )

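    # Forward data flow (logical/global shapes):
    #   x: [BATCH_SIZE, SEQUENCE_LEN, HIDDEN_SIZE]
    #   stage 0: split into heads, matmul with w_0 and w_1, fold heads back
    #            -> [SEQUENCE_LEN, BATCH_SIZE, HIDDEN_SIZE], then layernorm
    #   stage 1: matmul with w_2 and w_3
    #            -> [SEQUENCE_LEN, BATCH_SIZE, CLASS_NUM], transposed back to
    #               [BATCH_SIZE, SEQUENCE_LEN, CLASS_NUM]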
    def forward(self, x):
        # Layer 0 (pipeline stage 0): per-head feed-forward.
        tgt = paddle.transpose(x, [1, 0, 2])
        out = paddle.reshape(x, [BATCH_SIZE, SEQUENCE_LEN, NUM_HEAD, HEAD_DIM])
        # [BATCH_SIZE, SEQUENCE_LEN, NUM_HEAD, HEAD_DIM] -> [BATCH_SIZE, NUM_HEAD, SEQUENCE_LEN, HEAD_DIM]
        out = paddle.transpose(out, [0, 2, 1, 3])
        out = paddle.matmul(out, self.linear_0_weight)
        out = paddle.matmul(out, self.linear_1_weight)
        # [BATCH_SIZE, NUM_HEAD, SEQUENCE_LEN, HEAD_DIM] -> [SEQUENCE_LEN, BATCH_SIZE, HIDDEN_SIZE]
        out = paddle.transpose(out, [2, 0, 1, 3])
        out = paddle.reshape(out, [SEQUENCE_LEN, BATCH_SIZE, HIDDEN_SIZE])

        # SP region: resharding to Shard(0) on the sequence dim should lower
        # to a reduce_scatter rather than an allreduce.
        if self.is_sp:
            out = dist.reshard(out, self.pp0_mesh, self.sp_reshard_placement0)

        # Residual add disabled for now; `tgt` above is unused until it is
        # re-enabled.
        # out = out + tgt
        out = self.norm(out)

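        # Pipeline boundary: this reshard moves activations from pp0_mesh to
        # pp1_mesh and, in sp mode, gathers the sequence dim back to a
        # replicated layout for the stage-1 matmuls.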
        out = dist.reshard(out, self.pp1_mesh, self.sp_reshard_placement1)

        # Layer 1 (pipeline stage 1).
        out = paddle.matmul(out, self.linear_2_weight)
        out = paddle.matmul(out, self.linear_3_weight)
        out = paddle.transpose(out, [1, 0, 2])

        return out


class TestSimpleNetHybridStrategyForSemiAutoParallel(
    TestSimpleNetForSemiAutoParallel
):
    def __init__(self):
        self._dtype = os.getenv("dtype")
        self._backend = os.getenv("backend")
        self._seed = int(os.getenv("seed"))
        self._is_dp = os.getenv("is_dp") == "true"
        if self._is_dp:
            self.pp0_mesh = dist.ProcessMesh(
                [[0, 1], [2, 3]], dim_names=["dp", "mp"]
            )

        paddle.set_device(self._backend)

        self.set_random_seed(self._seed)
        self.init_single_card_net_result()

    def init_single_card_net_result(self):
        self.set_random_seed(self._seed)
        self.base_loss, self.base_parameters = self.run_dynamic(
            DemoNet("demo_weight", is_sp=False, is_dp=self._is_dp), is_sp=False
        )

    def init_input_data(self):
        image = np.random.randn(BATCH_SIZE, SEQUENCE_LEN, HIDDEN_SIZE).astype(
            'float32'
        )
        label = np.random.randn(BATCH_SIZE, SEQUENCE_LEN, CLASS_NUM).astype(
            'float32'
        )

        return paddle.to_tensor(image), paddle.to_tensor(label)

    def check_tensor_eq(self, a, b, rtol=1e-7, atol=0, verbose=True):
        np1 = a.astype("float32").numpy()
        np2 = b.astype("float32").numpy()
        np.testing.assert_allclose(
            np1, np2, rtol=rtol, atol=atol, verbose=verbose
        )

    def run_dynamic(self, layer, is_sp=False):
        # create the loss function
        loss_fn = nn.MSELoss()
        # run forward and backward for a few optimizer steps
        opt = paddle.optimizer.AdamW(
            learning_rate=0.1, parameters=layer.parameters()
        )
        for _ in range(5):
            image, label = self.init_input_data()
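            # With dp enabled, shard the batch dim of the input across the
            # "dp" mesh axis so each dp group consumes a different slice of
            # the batch.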
            if self._is_dp:
                image = dist.shard_tensor(
                    image, self.pp0_mesh, [Shard(0), Replicate()]
                )

            out = layer(image)

            loss = loss_fn(out, label)
            loss.backward()

            opt.step()
        return loss, layer.parameters()

    def test_dp_mp_sp_demo_net(self):
        self.set_random_seed(self._seed)
        model = DemoNet("dp_mp_hybrid_strategy", is_sp=True, is_dp=self._is_dp)

        (
            self.dp_mp_sp_loss,
            self.dp_mp_sp_parameters,
        ) = self.run_dynamic(model, is_sp=True)

        self.check_tensor_eq(self.dp_mp_sp_loss, self.base_loss)
        for param, param_base in zip(
            self.dp_mp_sp_parameters, self.base_parameters
        ):
            if param.grad is not None:
                self.check_tensor_eq(param, param_base)
                self.check_tensor_eq(param.grad, param_base.grad)

    def run_test_case(self):
        self.test_dp_mp_sp_demo_net()


if __name__ == '__main__':
    TestSimpleNetHybridStrategyForSemiAutoParallel().run_test_case()