
Commit 3816130

fix opt chunk offload (PaddlePaddle#76323)
* fix opt chunk offload
* add test
* fix test
* fix test
* fix UT test
1 parent 2973831 commit 3816130

File tree

5 files changed: +304 −0 lines changed


python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

Lines changed: 3 additions & 0 deletions
Lines changed: 3 additions & 0 deletions

@@ -838,6 +838,9 @@ def _build_comm_buffers(
             # here group_size is parameter size (GB)
             # optimizer states(float32) size is 6 times as much as parameter(bfloat16) size
             offload_buffer_size -= sum(opt_states_sizes)
+        else:
+            for param in parameters:
+                self._slice_params[param.name].is_offload_opt = False

         self._comm_buffer_list.append(buffer)
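To make the budget accounting in this hunk concrete, below is a minimal standalone sketch, not Paddle's actual implementation: the helper name plan_chunk_offload, the bfloat16 element size, and the flat 6x factor are illustrative assumptions taken from the comment in the hunk ("optimizer states(float32) size is 6 times as much as parameter(bfloat16) size").

# Minimal sketch (assumed, not Paddle's real API) of the offload-budget logic:
# each bf16 parameter element is 2 bytes, and its fp32 optimizer states
# (master weight + two Adam moments) take roughly 12 bytes, i.e. ~6x.
GB = 1024**3

def plan_chunk_offload(param_numels, offload_buffer_size_gb):
    """Return, per parameter chunk, whether its optimizer states still fit the budget."""
    budget = offload_buffer_size_gb
    decisions = []
    for numel in param_numels:
        opt_state_gb = numel * 2 * 6 / GB  # ~6x the bf16 parameter size
        if budget >= opt_state_gb:
            budget -= opt_state_gb
            decisions.append(True)   # chunk's optimizer states are offloaded
        else:
            decisions.append(False)  # budget exhausted: treat as is_offload_opt = False
    return decisions

if __name__ == "__main__":
    # e.g. a 1 GB budget and three bf16 parameter chunks
    print(plan_chunk_offload([10_000_000, 50_000_000, 80_000_000], 1.0))

With a 1 GB budget the first two chunks fit and the third does not, which mirrors what the added else branch does: parameters whose optimizer states are not offloaded get is_offload_opt cleared instead of being left in an inconsistent state.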
test/collective/fleet/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
@@ -372,6 +372,21 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
   set_tests_properties(test_parallel_dygraph_sharding_parallel
                        PROPERTIES TIMEOUT "400")
 endif()
+if((WITH_GPU) AND LOCAL_ALL_PLAT)
+  bash_test_modules(
+    test_parallel_dygraph_sharding_parallel_chunkoffload
+    START_BASH
+    ../../legacy_test/dist_test.sh
+    TIMEOUT
+    "600"
+    LABELS
+    "RUN_TYPE=DIST"
+    ENVS
+    "PADDLE_DIST_UT_PORT=21218;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
+  )
+  set_tests_properties(test_parallel_dygraph_sharding_parallel_chunkoffload
+                       PROPERTIES TIMEOUT "600")
+endif()
 if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
   bash_test_modules(
     test_parallel_dygraph_tensor_parallel

Lines changed: 251 additions & 0 deletions
@@ -0,0 +1,251 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import unittest

import numpy as np

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
    DygraphShardingOptimizer,
    DygraphShardingOptimizerV2,
)
from paddle.distributed.fleet.utils.mix_precision_utils import (
    MixPrecisionLayer,
    MixPrecisionOptimizer,
)

g_shard_split_param = int(os.environ.get("FLAGS_shard_split_param", 0))
g_shard_param_with_color = int(
    os.environ.get("FLAGS_shard_param_with_color", 0)
)

vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 10
seq_length = 2
batch_size = 4
STEPS = 10


class SimpleDPNet(paddle.nn.Layer):
    def __init__(
        self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
    ):
        super().__init__()
        self.linear1 = paddle.nn.Linear(
            hidden_size,
            inner_size,
            weight_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Assign(np_fc1)
            ),
            bias_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Constant(0.0)
            ),
        )

        self.linear2 = paddle.nn.Linear(
            inner_size,
            hidden_size,
            weight_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Assign(np_fc2)
            ),
            bias_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Constant(0.0)
            ),
        )

        self.linear3 = paddle.nn.Linear(
            hidden_size,
            output_size,
            weight_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Constant(0.0)
            ),
            bias_attr=paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.Constant(0.0)
            ),
        )

        self.embedding = paddle.nn.Embedding(
            vocab_size,
            hidden_size,
            weight_attr=paddle.nn.initializer.Constant(value=0.5),
        )

        if g_shard_param_with_color:
            for p in self.linear1.parameters():
                p.color = {'color': "linear1"}

            for p in self.linear2.parameters():
                p.color = {'color': "linear2"}

            for p in self.linear3.parameters():
                p.color = {'color': "linear3"}

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
        return x


class TestShardingV2ChunkOffload(unittest.TestCase):
    def setUp(self):
        random.seed(2021)
        np.random.seed(2021)
        paddle.seed(2021)

        self.strategy = fleet.DistributedStrategy()

        self.strategy.hybrid_configs = {
            "sharding_degree": 2,
            "dp_degree": 1,
            "mp_degree": 1,
            "pp_degree": 1,
        }
        self.strategy.hybrid_configs["sharding_configs"].split_param = True
        self.strategy.hybrid_configs[
            "sharding_configs"
        ].offload_opt_buffer_size = 0
        fleet.init(is_collective=True, strategy=self.strategy)
        self.data = [
            np.random.randint(
                0,
                vocab_size,
                (
                    batch_size,
                    seq_length,
                ),
            )
            for _ in range(STEPS)
        ]

    def train_batch(self, batch, model, optimizer):
        output = model(batch)
        loss = output.mean()
        loss.backward()  # do backward
        optimizer.step()  # update parameters
        optimizer.clear_grad()
        return loss

    def build_optimizer(self, model, strategy=None, Optimizer="adam"):
        clip = paddle.nn.ClipGradByGlobalNorm(0.5)
        if Optimizer == "adam":
            optimizer = paddle.optimizer.AdamW(
                parameters=model.parameters(),
                learning_rate=0.001,
                weight_decay=0.00001,
                grad_clip=clip,
            )
        else:
            optimizer = paddle.optimizer.Momentum(
                learning_rate=0.001,
                parameters=model.parameters(),
                grad_clip=clip,
            )
        return optimizer

    def build_model_optimizer(self, Optimizer="adam", amp_level=None):
        np_fc1 = np.random.random_sample((hidden_size, inner_size))
        np_fc2 = np.random.random_sample((inner_size, hidden_size))

        model_a = SimpleDPNet(
            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
        )
        optimizer_a = self.build_optimizer(
            model_a,
            strategy=self.strategy,
            Optimizer=Optimizer,
        )

        model_b = SimpleDPNet(
            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
        )
        optimizer_b = self.build_optimizer(
            model_b,
            strategy=self.strategy,
            Optimizer=Optimizer,
        )

        if amp_level is not None and amp_level == "O2":
            model_a = MixPrecisionLayer(model_a)
            optimizer_a = MixPrecisionOptimizer(optimizer_a)
            model_b = MixPrecisionLayer(model_b)
            optimizer_b = MixPrecisionOptimizer(optimizer_b)

        model_a = fleet.distributed_model(model_a)
        optimizer_a = fleet.distributed_optimizer(optimizer_a)
        model_b = fleet.distributed_model(model_b)
        optimizer_b = fleet.distributed_optimizer(optimizer_b)

        optimizer_a._set_all_gather_overlap_forward(True, model_a)
        optimizer_b._set_all_gather_overlap_forward(False, model_b)
        return model_a, optimizer_a, model_b, optimizer_b

    def sharding_model(self, Optimizer, sharded_accumulators, amp_level=None):
        model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer(
            Optimizer=Optimizer,
            amp_level=amp_level,
        )
        opt_cls = (
            DygraphShardingOptimizerV2 if True else DygraphShardingOptimizer
        )
        self.assertTrue(isinstance(optimizer_a._inner_opt, opt_cls))

        for idx in range(STEPS):
            if idx == 2 and paddle.distributed.get_rank() == 0 and not True:
                self.assertTrue(
                    set(optimizer_a._inner_opt._inner_opt.state_dict().keys())
                    == sharded_accumulators
                )

            if paddle.distributed.get_rank() == 0:
                batch_sharding = paddle.to_tensor(self.data[idx][:2])
            else:
                batch_sharding = paddle.to_tensor(self.data[idx][2:])

            batch_single = paddle.to_tensor(self.data[idx])
            loss_a = self.train_batch(batch_sharding, model_a, optimizer_a)
            loss_b = self.train_batch(batch_single, model_b, optimizer_b)

            for j in range(len(model_a.parameters())):
                np.testing.assert_allclose(
                    model_a.parameters()[j].numpy(),
                    model_b.parameters()[j].numpy(),
                    rtol=1e-6,
                )

    def test_all_gather_overlap_forward(self):
        if True:
            sharded_accumulators = {
                'linear_12.b_0_velocity_0',
                'linear_13.b_0_velocity_0',
                'linear_14.b_0_velocity_0',
                'embedding_4.w_0_velocity_0',
            }
            self.sharding_model(
                Optimizer="Momentum",
                sharded_accumulators=sharded_accumulators,
                amp_level="O2",
            )


if __name__ == "__main__":
    unittest.main()

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

from legacy_test.test_parallel_dygraph_dataparallel import (
    TestMultipleAccelerators,
)


class TestHybridParallelShardingV2ChunkOffload(TestMultipleAccelerators):
    # check sharding logic as well as the accuracy with single mode
    def test_hybrid_parallel_sharding_v2_chunk_offload(self):
        # test sharding v2 chunk offload
        os.environ["FLAGS_shard_split_param"] = "1"
        self.run_mnist_2accelerators(
            'hybrid_parallel_sharding_model_chunkoffload.py'
        )


if __name__ == "__main__":
    unittest.main()
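
For local debugging, a hedged sketch of launching the model script above on two devices outside of CTest; the paddle.distributed.launch invocation is standard Paddle CLI usage but is an assumption here, not part of this commit (newer releases also accept --devices in place of --gpus).

# Hypothetical launcher (not part of this commit): run the new model script on
# two GPUs with FLAGS_shard_split_param=1, mirroring what the unit test does
# through TestMultipleAccelerators.run_mnist_2accelerators.
import os
import subprocess
import sys

env = dict(os.environ, FLAGS_shard_split_param="1")
subprocess.run(
    [
        sys.executable,
        "-m",
        "paddle.distributed.launch",
        "--gpus=0,1",
        "hybrid_parallel_sharding_model_chunkoffload.py",
    ],
    env=env,
    check=True,
)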

tools/parallel_UT_rule.py

Lines changed: 1 addition & 0 deletions
@@ -519,6 +519,7 @@
     'test_new_group',
     'test_imperative_signal_handler',
     'test_parallel_dygraph_sharding_parallel',
+    'test_parallel_dygraph_sharding_parallel_chunkoffload',
     'test_dist_hapi_model',
     'test_dist_mnist_gradient_merge',
     'test_rnn_dp',
