
Commit dcb37aa

Authored by iosmers, Xreki, and tianhaodongbd

fix stage2 main_grad acc bug (#59142) (#60030)

* fix stage2 main_grad acc bug
* update code according to review suggestions
* scale in opt
* merge grad scale
* add note
* delete debug info
* keep offload unchanged
* Optimize the BF16 unittest of sharding stage2 and stage3.
* fix stage3 bug
* add fp16 judge
* add init
* add fp16
* fix grad clip
* add if data type is fp16
* change if location
* delete fault arg
* add enum.value

---------

Co-authored-by: Liu Yiqun <[email protected]>
Co-authored-by: tianhaodongbd <[email protected]>

1 parent 92f1fb7 · commit dcb37aa

5 files changed (+273 −283 lines)

paddle/phi/kernels/fusion/gpu/fused_linear_param_grad_add_kernel.cu

Lines changed: 7 additions & 0 deletions

@@ -134,6 +134,13 @@ void FusedLinearParamGradAdd(const Context &ctx,
 
   bool use_addto = false;
   if (dweight_out) {
+    if (dweight_out->dtype() == phi::DataType::FLOAT16) {
+      LOG_FIRST_N(WARNING, 1)
+          << "fused_linear_param_grad_add op may have problems when "
+             "master_grad is not enabled and use fp16-O2 in stage2, users "
+             "should pay attention to the correctness of the result of the "
+             "grad accumulation in stage2.";
+    }
     if (dweight) {
       use_addto = true;
       *dweight_out = dweight.get();
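The warning exists because fused_linear_param_grad_add accumulates the weight gradient in place (use_addto), and doing that accumulation directly in float16 can silently drop small contributions once the accumulator grows large. A minimal illustration of the rounding effect with plain Paddle tensors and toy values (not the fused kernel itself), assuming any recent Paddle build:

import paddle

# A large fp16 gradient accumulator and a small incoming gradient.
acc_fp16 = paddle.full([4], 1024.0, dtype="float16")
grad = paddle.full([4], 0.25, dtype="float16")

# Accumulating in fp16: at magnitude 1024 the fp16 spacing is 1.0,
# so adding 0.25 rounds back to 1024.0 and the update is lost.
print((acc_fp16 + grad) - acc_fp16)  # -> all zeros

# Accumulating in fp32 (what master_grad provides) keeps the update.
acc_fp32 = acc_fp16.astype("float32")
print((acc_fp32 + grad.astype("float32")) - acc_fp32)  # -> 0.25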

python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py

Lines changed: 19 additions & 6 deletions

@@ -194,13 +194,26 @@ def __init__(
             and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
             and not offload
         ):
-            self._optim._grad_clip = HybridParallelClipGrad(
-                self._optim._grad_clip, hcg
-            )
+            if self.use_main_grad:
+                self._optim._inner_opt._grad_clip = HybridParallelClipGrad(
+                    self._optim._inner_opt._grad_clip, hcg
+                )
+            else:
+                self._optim._grad_clip = HybridParallelClipGrad(
+                    self._optim._grad_clip, hcg
+                )
         else:
-            self._optim._grad_clip = GroupShardedClipGrad(
-                self._optim._grad_clip, paddle.get_device(), self._group
-            )
+            if self.use_main_grad:
+                self._optim._inner_opt._grad_clip = GroupShardedClipGrad(
+                    self._optim._inner_opt._grad_clip,
+                    paddle.get_device(),
+                    self._group,
+                )
+            else:
+                self._optim._grad_clip = GroupShardedClipGrad(
+                    self._optim._grad_clip, paddle.get_device(), self._group
+                )
+
         if self._optim._parameter_list and isinstance(
             self._optim._parameter_list[0], dict
         ):
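This branch now distinguishes whether the optimizer passed in is itself a wrapper: with use_main_grad the grad-clip object that actually runs lives on self._optim._inner_opt, so that is the one that must be replaced by the hybrid/sharded clip wrapper. A schematic sketch of the pattern with stand-in classes (ClipWrapper, InnerOpt, MixedPrecisionOpt and attach_clip are hypothetical names for illustration, not Paddle APIs):

class ClipWrapper:
    """Stand-in for HybridParallelClipGrad / GroupShardedClipGrad."""
    def __init__(self, inner_clip):
        self.inner_clip = inner_clip

class InnerOpt:
    def __init__(self):
        self._grad_clip = "base_clip"

class MixedPrecisionOpt:
    """Stand-in for the wrapper optimizer used when use_main_grad is True."""
    def __init__(self):
        self._inner_opt = InnerOpt()

def attach_clip(optim, use_main_grad):
    # Mirror of the branch above: with main_grad, wrap the clip held by
    # the inner optimizer; otherwise wrap the outer optimizer's clip.
    if use_main_grad:
        optim._inner_opt._grad_clip = ClipWrapper(optim._inner_opt._grad_clip)
    else:
        optim._grad_clip = ClipWrapper(optim._grad_clip)

opt = MixedPrecisionOpt()
attach_clip(opt, use_main_grad=True)
print(type(opt._inner_opt._grad_clip).__name__)  # ClipWrapper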

python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py

Lines changed: 32 additions & 15 deletions

@@ -155,6 +155,8 @@ def __init__(
         # Set backward pass hooks
         self._bw_hooks = []
 
+        self.scale_in_opt = False
+
         # TODO (Baibaifan) Set tasks flow support asynchronous communicate
         # self._tasks_flow = deque()
 
@@ -232,13 +234,20 @@ def _clear_gradients(self):
 
     def _grad_scale(self):
         """
-        Before the optimization, scale the gradients before allreduce of dp_group.
+        this function will do 2 things:
+        1. Before the optimization, scale main_grad to support gradient merge if param has main_grad, or to support fused_linear_param_grad_add gradient merge.
+        2. Before the optimization, scale the gradients before allreduce of dp_group.
         """
 
-        if self._dp_group is None or self._dp_group.nranks <= 1:
-            return
+        need_dp_scale = self._dp_group is not None and self._dp_group.nranks > 1
+        if self.scale_in_opt:
+            scale_factor = self._world_size_scaling
         else:
-            scale_factor = 1.0 / (self._dp_group.nranks)
+            scale_factor = 1.0
+
+        if need_dp_scale:
+            dp_scale_factor = 1.0 / (self._dp_group.nranks)
+            scale_factor = scale_factor * dp_scale_factor
 
         # Scale grad storages
         for dtype in self._grad_storages.keys():
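The rewritten _grad_scale folds two scalings into a single factor: the 1/world_size sharding average, applied here only when the backward hooks deferred it (scale_in_opt), and the 1/dp_nranks average for the outer data-parallel group. A pure-Python restatement of that arithmetic (combined_scale_factor and its arguments are illustrative names, not part of Paddle):

def combined_scale_factor(scale_in_opt, world_size, dp_nranks):
    """Return the single factor applied to grads / main_grads in _grad_scale.

    scale_in_opt -- True when the backward hook deferred the 1/world_size
                    scaling to the optimizer (main_grad or fused grad-add).
    dp_nranks    -- size of the outer data-parallel group; None or 1 means
                    no extra allreduce averaging is needed.
    """
    factor = 1.0 / world_size if scale_in_opt else 1.0
    if dp_nranks is not None and dp_nranks > 1:
        factor *= 1.0 / dp_nranks
    return factor

# e.g. sharding world size 8, dp group of 2, scaling deferred to the optimizer:
assert combined_scale_factor(True, 8, 2) == 1.0 / 16
# scaling already done in the backward hook, dp group of 2:
assert combined_scale_factor(False, 8, 2) == 0.5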
@@ -249,7 +258,6 @@ def _grad_scale(self):
                 self._grad_storages[dtype][self._rank].buffer.scale_(
                     scale=scale_factor
                 )
-
         # Scale grads of params
         with paddle.no_grad():
             for param in self._trainable_params:
@@ -258,11 +266,14 @@ def _grad_scale(self):
                     param.main_grad.scale_(scale=scale_factor)
                 elif param.grad is not None:
                     param.grad.scale_(scale=scale_factor)
-                    # param._reset_grad_inplace_version(True)
 
-                # Scale grads of master params with offload strategy
+        # Scale grads of master params with offload strategy
         if self._offload:
-            self._sharding_optimizers[0]._offload_scale_grad(scale_factor)
+            if need_dp_scale is False:
+                return
+            self._sharding_optimizers[0]._offload_scale_grad(
+                scale=1.0 / (self._dp_group.nranks)
+            )
 
     def _init_internal_storage(self, needs_fresh):
         """
@@ -379,15 +390,21 @@ def _set_reduce_overlap(self, reduce_overlap):
     def _get_scaled_grad_fn(self, param):
         @paddle.autograd.no_grad()
         def scale(grad):
-            if hasattr(param, "main_grad"):
-                param.main_grad.scale_(self._world_size_scaling)
-            else:
-                if grad is not None and grad._is_initialized():
+            # do gradient scale separately
+            # For grad scale, we need to do it in the backward hook due to fp16 may overflow if we first add grad and then scale
+            # For main_grad scale and fused_linear_param_grad_add, we do scale in the optimizer.
+            if not self.scale_in_opt:
+                if (
+                    not hasattr(param, "main_grad")
+                    and grad is not None
+                    and grad.dtype == Type.fp16.value
+                ):
+                    assert (
+                        grad._is_initialized()
+                    ), "grad should be initialized in stage2"
                     grad.scale_(self._world_size_scaling)
                 else:
-                    assert param.grad is not None
-                    assert param.grad._is_initialized()
-                    param.grad.scale_(self._world_size_scaling)
+                    self.scale_in_opt = True
 
         return scale
 
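The rewritten hook makes the scaling decision per parameter: a raw fp16 grad without a main_grad must be scaled immediately in the backward hook, because first summing unscaled fp16 grads and scaling afterwards risks overflow; every other case just flips scale_in_opt so that _grad_scale applies the 1/world_size factor later, at optimizer time. A stand-alone restatement of that decision (decide_scale_in_hook and the SimpleNamespace stand-ins are illustrative, not Paddle code):

from types import SimpleNamespace

FP16 = "float16"

def decide_scale_in_hook(param, grad):
    """Return True if the grad must be scaled inside the backward hook."""
    # Matches the condition in _get_scaled_grad_fn: only raw fp16 grads
    # without a main_grad are scaled eagerly; all other cases defer the
    # 1/world_size scaling to the optimizer step (scale_in_opt = True).
    return (
        not hasattr(param, "main_grad")
        and grad is not None
        and grad.dtype == FP16
    )

p_fp16 = SimpleNamespace()                    # plain fp16 param
p_main = SimpleNamespace(main_grad=object())  # param with an fp32 main_grad
g = SimpleNamespace(dtype=FP16)

assert decide_scale_in_hook(p_fp16, g) is True   # scale now, in the hook
assert decide_scale_in_hook(p_main, g) is False  # scale later, in _grad_scale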
