|
| 1 | +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import os |
| 16 | +import random |
| 17 | +import unittest |
| 18 | + |
| 19 | +import numpy as np |
| 20 | + |
| 21 | +import paddle |
| 22 | +from paddle.distributed import fleet |
| 23 | +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( |
| 24 | + DygraphShardingOptimizer, |
| 25 | + DygraphShardingOptimizerV2, |
| 26 | +) |
| 27 | +from paddle.distributed.fleet.utils.mix_precision_utils import ( |
| 28 | + MixPrecisionLayer, |
| 29 | + MixPrecisionOptimizer, |
| 30 | +) |
| 31 | + |
| 32 | +g_shard_split_param = int(os.environ.get("FLAGS_shard_split_param", 0)) |
| 33 | +g_shard_param_with_color = int( |
| 34 | + os.environ.get("FLAGS_shard_param_with_color", 0) |
| 35 | +) |
| 36 | + |
| 37 | +vocab_size = 20 |
| 38 | +hidden_size = 10 |
| 39 | +inner_size = 8 |
| 40 | +output_size = 10 |
| 41 | +seq_length = 2 |
| 42 | +batch_size = 4 |
| 43 | +STEPS = 10 |
| 44 | + |
| 45 | + |
| 46 | +class SimpleDPNet(paddle.nn.Layer): |
| 47 | + def __init__( |
| 48 | + self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 |
| 49 | + ): |
| 50 | + super().__init__() |
| 51 | + self.linear1 = paddle.nn.Linear( |
| 52 | + hidden_size, |
| 53 | + inner_size, |
| 54 | + weight_attr=paddle.framework.ParamAttr( |
| 55 | + initializer=paddle.nn.initializer.Assign(np_fc1) |
| 56 | + ), |
| 57 | + bias_attr=paddle.framework.ParamAttr( |
| 58 | + initializer=paddle.nn.initializer.Constant(0.0) |
| 59 | + ), |
| 60 | + ) |
| 61 | + |
| 62 | + self.linear2 = paddle.nn.Linear( |
| 63 | + inner_size, |
| 64 | + hidden_size, |
| 65 | + weight_attr=paddle.framework.ParamAttr( |
| 66 | + initializer=paddle.nn.initializer.Assign(np_fc2) |
| 67 | + ), |
| 68 | + bias_attr=paddle.framework.ParamAttr( |
| 69 | + initializer=paddle.nn.initializer.Constant(0.0) |
| 70 | + ), |
| 71 | + ) |
| 72 | + |
| 73 | + self.linear3 = paddle.nn.Linear( |
| 74 | + hidden_size, |
| 75 | + output_size, |
| 76 | + weight_attr=paddle.framework.ParamAttr( |
| 77 | + initializer=paddle.nn.initializer.Constant(0.0) |
| 78 | + ), |
| 79 | + bias_attr=paddle.framework.ParamAttr( |
| 80 | + initializer=paddle.nn.initializer.Constant(0.0) |
| 81 | + ), |
| 82 | + ) |
| 83 | + |
| 84 | + self.embedding = paddle.nn.Embedding( |
| 85 | + vocab_size, |
| 86 | + hidden_size, |
| 87 | + weight_attr=paddle.nn.initializer.Constant(value=0.5), |
| 88 | + ) |
| 89 | + |
| 90 | + if g_shard_param_with_color: |
| 91 | + for p in self.linear1.parameters(): |
| 92 | + p.color = {'color': "linear1"} |
| 93 | + |
| 94 | + for p in self.linear2.parameters(): |
| 95 | + p.color = {'color': "linear2"} |
| 96 | + |
| 97 | + for p in self.linear3.parameters(): |
| 98 | + p.color = {'color': "linear3"} |
| 99 | + |
| 100 | + def forward(self, x): |
| 101 | + x = self.embedding(x) |
| 102 | + x = self.linear1(x) |
| 103 | + x = self.linear2(x) |
| 104 | + x = self.linear3(x) |
| 105 | + x = paddle.matmul(x, self.embedding.weight, transpose_y=True) |
| 106 | + return x |
| 107 | + |
| 108 | + |
| 109 | +class TestShardingV2ChunkOffload(unittest.TestCase): |
| 110 | + def setUp(self): |
| 111 | + random.seed(2021) |
| 112 | + np.random.seed(2021) |
| 113 | + paddle.seed(2021) |
| 114 | + |
| 115 | + self.strategy = fleet.DistributedStrategy() |
| 116 | + |
| 117 | + self.strategy.hybrid_configs = { |
| 118 | + "sharding_degree": 2, |
| 119 | + "dp_degree": 1, |
| 120 | + "mp_degree": 1, |
| 121 | + "pp_degree": 1, |
| 122 | + } |
| 123 | + self.strategy.hybrid_configs["sharding_configs"].split_param = True |
| 124 | + self.strategy.hybrid_configs[ |
| 125 | + "sharding_configs" |
| 126 | + ].offload_opt_buffer_size = 0 |
| 127 | + fleet.init(is_collective=True, strategy=self.strategy) |
| 128 | + self.data = [ |
| 129 | + np.random.randint( |
| 130 | + 0, |
| 131 | + vocab_size, |
| 132 | + ( |
| 133 | + batch_size, |
| 134 | + seq_length, |
| 135 | + ), |
| 136 | + ) |
| 137 | + for _ in range(STEPS) |
| 138 | + ] |
| 139 | + |
| 140 | + def train_batch(self, batch, model, optimizer): |
| 141 | + output = model(batch) |
| 142 | + loss = output.mean() |
| 143 | + loss.backward() # do backward |
| 144 | + optimizer.step() # update parameters |
| 145 | + optimizer.clear_grad() |
| 146 | + return loss |
| 147 | + |
| 148 | + def build_optimizer(self, model, strategy=None, Optimizer="adam"): |
| 149 | + clip = paddle.nn.ClipGradByGlobalNorm(0.5) |
| 150 | + if Optimizer == "adam": |
| 151 | + optimizer = paddle.optimizer.AdamW( |
| 152 | + parameters=model.parameters(), |
| 153 | + learning_rate=0.001, |
| 154 | + weight_decay=0.00001, |
| 155 | + grad_clip=clip, |
| 156 | + ) |
| 157 | + else: |
| 158 | + optimizer = paddle.optimizer.Momentum( |
| 159 | + learning_rate=0.001, |
| 160 | + parameters=model.parameters(), |
| 161 | + grad_clip=clip, |
| 162 | + ) |
| 163 | + return optimizer |
| 164 | + |
| 165 | + def build_model_optimizer(self, Optimizer="adam", amp_level=None): |
| 166 | + np_fc1 = np.random.random_sample((hidden_size, inner_size)) |
| 167 | + np_fc2 = np.random.random_sample((inner_size, hidden_size)) |
| 168 | + |
| 169 | + model_a = SimpleDPNet( |
| 170 | + vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 |
| 171 | + ) |
| 172 | + optimizer_a = self.build_optimizer( |
| 173 | + model_a, |
| 174 | + strategy=self.strategy, |
| 175 | + Optimizer=Optimizer, |
| 176 | + ) |
| 177 | + |
| 178 | + model_b = SimpleDPNet( |
| 179 | + vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 |
| 180 | + ) |
| 181 | + optimizer_b = self.build_optimizer( |
| 182 | + model_b, |
| 183 | + strategy=self.strategy, |
| 184 | + Optimizer=Optimizer, |
| 185 | + ) |
| 186 | + |
| 187 | + if amp_level is not None and amp_level == "O2": |
| 188 | + model_a = MixPrecisionLayer(model_a) |
| 189 | + optimizer_a = MixPrecisionOptimizer(optimizer_a) |
| 190 | + model_b = MixPrecisionLayer(model_b) |
| 191 | + optimizer_b = MixPrecisionOptimizer(optimizer_b) |
| 192 | + |
| 193 | + model_a = fleet.distributed_model(model_a) |
| 194 | + optimizer_a = fleet.distributed_optimizer(optimizer_a) |
| 195 | + model_b = fleet.distributed_model(model_b) |
| 196 | + optimizer_b = fleet.distributed_optimizer(optimizer_b) |
| 197 | + |
| 198 | + optimizer_a._set_all_gather_overlap_forward(True, model_a) |
| 199 | + optimizer_b._set_all_gather_overlap_forward(False, model_b) |
| 200 | + return model_a, optimizer_a, model_b, optimizer_b |
| 201 | + |
| 202 | + def sharding_model(self, Optimizer, sharded_accumulators, amp_level=None): |
| 203 | + model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer( |
| 204 | + Optimizer=Optimizer, |
| 205 | + amp_level=amp_level, |
| 206 | + ) |
| 207 | + opt_cls = ( |
| 208 | + DygraphShardingOptimizerV2 if True else DygraphShardingOptimizer |
| 209 | + ) |
| 210 | + self.assertTrue(isinstance(optimizer_a._inner_opt, opt_cls)) |
| 211 | + |
| 212 | + for idx in range(STEPS): |
| 213 | + if idx == 2 and paddle.distributed.get_rank() == 0 and not True: |
| 214 | + self.assertTrue( |
| 215 | + set(optimizer_a._inner_opt._inner_opt.state_dict().keys()) |
| 216 | + == sharded_accumulators |
| 217 | + ) |
| 218 | + |
| 219 | + if paddle.distributed.get_rank() == 0: |
| 220 | + batch_sharding = paddle.to_tensor(self.data[idx][:2]) |
| 221 | + else: |
| 222 | + batch_sharding = paddle.to_tensor(self.data[idx][2:]) |
| 223 | + |
| 224 | + batch_single = paddle.to_tensor(self.data[idx]) |
| 225 | + loss_a = self.train_batch(batch_sharding, model_a, optimizer_a) |
| 226 | + loss_b = self.train_batch(batch_single, model_b, optimizer_b) |
| 227 | + |
| 228 | + for j in range(len(model_a.parameters())): |
| 229 | + np.testing.assert_allclose( |
| 230 | + model_a.parameters()[j].numpy(), |
| 231 | + model_b.parameters()[j].numpy(), |
| 232 | + rtol=1e-6, |
| 233 | + ) |
| 234 | + |
| 235 | + def test_all_gather_overlap_forward(self): |
| 236 | + if True: |
| 237 | + sharded_accumulators = { |
| 238 | + 'linear_12.b_0_velocity_0', |
| 239 | + 'linear_13.b_0_velocity_0', |
| 240 | + 'linear_14.b_0_velocity_0', |
| 241 | + 'embedding_4.w_0_velocity_0', |
| 242 | + } |
| 243 | + self.sharding_model( |
| 244 | + Optimizer="Momentum", |
| 245 | + sharded_accumulators=sharded_accumulators, |
| 246 | + amp_level="O2", |
| 247 | + ) |
| 248 | + |
| 249 | + |
| 250 | +if __name__ == "__main__": |
| 251 | + unittest.main() |
0 commit comments