
Commit 9d1130b

support bw split (#10915)
* support bw split
* fix bug
* polish code
1 parent e911a76 commit 9d1130b

File tree

3 files changed (+118, -21 lines):

paddlenlp/transformers/deepseek_v2/modeling.py
paddlenlp/transformers/deepseek_v2/modeling_pp.py
paddlenlp/transformers/fp8_utils.py

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 46 additions & 4 deletions
@@ -80,6 +80,8 @@
 
 FA_VERSION = int(os.getenv("FA_VERSION", 2))
 
+from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore
+
 from ..fp8_utils import FP8KeepXLinear, FP8Linear, FP8Mlp
 from .fp8_linear import Linear
 

@@ -1084,7 +1086,8 @@ def qkv_pre_process(
     target_key_value_shape = [0, 0, num_heads, qk_nope_head_dim + v_head_dim]
 
     q = q.reshape(shape=target_query_shape)
-    q_nope, q_pe = paddle.split(q, [qk_nope_head_dim, qk_rope_head_dim], axis=-1)
+    q_nope = q[..., :qk_nope_head_dim]
+    q_pe = q[..., qk_nope_head_dim:]
 
     # DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64
 
@@ -1094,8 +1097,9 @@ def qkv_pre_process(
 
     # self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128+64
     # self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) = config.qk_nope_head_dim + self.v_head_dim = 128+128
+    k_nope = kv[..., :qk_nope_head_dim]
+    value_states = kv[..., qk_nope_head_dim:]
 
-    k_nope, value_states = paddle.split(kv, [qk_nope_head_dim, v_head_dim], axis=-1)
     kv_seq_len = value_states.shape[1]
 
     cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
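Note on the two hunks above: the slicing rewrite produces the same tensors as the removed paddle.split calls. A small standalone check, with illustrative shapes that are assumptions and not code from this commit:

import paddle

qk_nope_head_dim, qk_rope_head_dim = 128, 64

# Hypothetical [batch, seq_len, num_heads, q_head_dim] query tensor.
q = paddle.randn([2, 8, 16, qk_nope_head_dim + qk_rope_head_dim])

# Removed form: paddle.split along the last axis.
q_nope_ref, q_pe_ref = paddle.split(q, [qk_nope_head_dim, qk_rope_head_dim], axis=-1)

# New form: plain slicing over the same ranges.
q_nope = q[..., :qk_nope_head_dim]
q_pe = q[..., qk_nope_head_dim:]

assert paddle.equal_all(q_nope, q_nope_ref).item()
assert paddle.equal_all(q_pe, q_pe_ref).item()
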
@@ -1434,15 +1438,53 @@ def backward(ctx, dout):
         paddle.base.core._set_has_grad(False)
 
         # call up proj
-        d_kv_ln_t, d_kv_up_weight = _C_ops.matmul_grad(kv_ln_t, kv_up_weight, d_kv, False, False)
+        if hasattr(kv_up_weight, "main_grad"):
+            d_kv_ln_t = paddle.matmul(d_kv, kv_up_weight, transpose_y=True)
+
+            def kv_up_weight_grad(kv_ln_t, d_kv, kv_up_weight):
+                with paddle.no_grad():
+                    w_grad_t = paddle.matmul(
+                        kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]), d_kv.reshape([-1, d_kv.shape[-1]]), transpose_x=True
+                    )
+                    kv_up_weight.main_grad.add_(w_grad_t)
+
+            if WeightGradStore.enabled:
+                WeightGradStore.put(partial(kv_up_weight_grad, kv_ln_t, d_kv, kv_up_weight))
+            else:
+                kv_up_weight_grad(kv_ln_t, d_kv, kv_up_weight)
+            d_kv_up_weight = None
+        else:
+            d_kv_ln_t, d_kv_up_weight = _C_ops.matmul_grad(kv_ln_t, kv_up_weight, d_kv, False, False)
 
         d_compressed_kv, d_kv_ln_weight = fused_ln.fused_rms_norm_grad_func(
             compressed_kv, kv_ln_weight, kv_ln_invar, d_kv_ln_t, eps
         )
 
         d_kv_init = paddle.concat([d_compressed_kv, d_k_pe], axis=-1)
 
-        d_q_ln_t, d_q_up_weight = _C_ops.matmul_grad(q_ln_t, q_up_weight, d_q, False, False)
+        if hasattr(q_up_weight, "main_grad"):
+            d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True)
+
+            def q_up_weight_grad(q_ln_t, d_q, q_up_weight):
+                with paddle.no_grad():
+                    w_grad_t = paddle.matmul(
+                        q_ln_t.reshape([-1, q_ln_t.shape[-1]]), d_q.reshape([-1, d_q.shape[-1]]), transpose_x=True
+                    )
+                    q_up_weight.main_grad.add_(w_grad_t)
+
+            if WeightGradStore.enabled:
+                WeightGradStore.put(partial(q_up_weight_grad, q_ln_t, d_q, q_up_weight))
+            else:
+                q_up_weight_grad(q_ln_t, d_q, q_up_weight)
+            d_q_up_weight = None
+        else:
+            d_q_ln_t, d_q_up_weight = _C_ops.matmul_grad(q_ln_t, q_up_weight, d_q, False, False)
 
         d_q_init, d_q_ln_weight = fused_ln.fused_rms_norm_grad_func(q_init, q_ln_weight, q_ln_invar, d_q_ln_t, eps)
 
         return d_q_init, d_kv_init, d_q_ln_weight, d_kv_ln_weight, d_q_up_weight, d_kv_up_weight
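The hunk above splits what _C_ops.matmul_grad previously did in a single call: the activation gradient (d_kv_ln_t / d_q_ln_t) is computed immediately, while the weight gradient is wrapped in a closure that accumulates into main_grad and can be queued in WeightGradStore. A minimal standalone check of that decomposition against Paddle autograd; the shapes and names are illustrative assumptions, not code from this commit:

import paddle

x = paddle.randn([2, 8, 64])      # stand-in for the saved layer input (e.g. kv_ln_t)
w = paddle.randn([64, 128])       # stand-in for the up-projection weight
x.stop_gradient = False
w.stop_gradient = False

y = paddle.matmul(x, w)
dout = paddle.randn(y.shape)      # stand-in for the incoming gradient (e.g. d_kv)
dx_ref, dw_ref = paddle.grad([y], [x, w], grad_outputs=[dout])

# Input gradient: computed right away in the split backward.
dx = paddle.matmul(dout, w, transpose_y=True)
# Weight gradient: the part that can be deferred and run later.
dw = paddle.matmul(x.reshape([-1, x.shape[-1]]), dout.reshape([-1, dout.shape[-1]]), transpose_x=True)

assert paddle.allclose(dx, dx_ref, atol=1e-4).item()
assert paddle.allclose(dw, dw_ref, atol=1e-4).item()
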

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 22 additions & 1 deletion
@@ -27,6 +27,14 @@
     ScheduleNode,
     SharedLayerDesc,
 )
+from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import (
+    WeightGradStore,
+)
+
+try:
+    from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import EventStore
+except ImportError:
+    EventStore = None
 from paddle.distributed.fleet.recompute.recompute import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
 
@@ -714,7 +722,10 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,
         return ret
 
     def mlp_backward_dw(self):
-        self.fp8_fusion_moe_node.mlp_node.backward_dw()
+        if WeightGradStore.enabled:
+            WeightGradStore.put(self.fp8_fusion_moe_node.mlp_node.backward_dw)
+        else:
+            self.fp8_fusion_moe_node.mlp_node.backward_dw()
 
     def mlp_backward(self, output_grad):
         if self.send_mtp_embed:
@@ -914,8 +925,18 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
 
         paddle.base.core.nvprof_nvtx_pop()
         paddle.base.core.nvprof_nvtx_push("attn_backward")
+        assert WeightGradStore.funcs_queue.empty()
+        WeightGradStore.enabled = True
         output_grad = self.backward_node.attn_backward(output_grad)
         event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
+
+        if EventStore is not None:
+            EventStore.set(event_to_wait)
+
+        WeightGradStore.enabled = False
+        WeightGradStore.flush()
+        WeightGradStore.pop()
+        assert WeightGradStore.funcs_queue.empty()
 
         paddle.base.core.nvprof_nvtx_pop()
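The enable/flush/pop sequence above is the heart of the backward split: weight-gradient work queued while the store is enabled is only drained after the attention backward has produced its activation gradients. A toy, self-contained illustration of that lifecycle; the class below is a simplified stand-in mirroring the interface used here (enabled / put / flush / pop / funcs_queue), not the real paddle zero_bubble_utils implementation:

import queue
from functools import partial


class ToyWeightGradStore:
    """Simplified stand-in for WeightGradStore, for illustration only."""

    enabled = False
    cache = []
    funcs_queue = queue.Queue()

    @classmethod
    def put(cls, func):
        cls.cache.append(func)

    @classmethod
    def flush(cls):
        cls.funcs_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls):
        for func in cls.funcs_queue.get():
            func()


def attn_backward_stub():
    # Stand-in for self.backward_node.attn_backward(output_grad): the dX math would
    # run here eagerly, while each dW kernel is queued when the store is enabled.
    dw_work = partial(print, "running deferred dW")
    if ToyWeightGradStore.enabled:
        ToyWeightGradStore.put(dw_work)
    else:
        dw_work()


assert ToyWeightGradStore.funcs_queue.empty()
ToyWeightGradStore.enabled = True    # queue dW work instead of running it
attn_backward_stub()
ToyWeightGradStore.enabled = False
ToyWeightGradStore.flush()           # seal the batch of deferred dW callables
ToyWeightGradStore.pop()             # drain it later, e.g. while waiting on dispatch/combine
assert ToyWeightGradStore.funcs_queue.empty()
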

paddlenlp/transformers/fp8_utils.py

Lines changed: 50 additions & 16 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from functools import partial
 
 import numpy
 import paddle
@@ -27,6 +28,8 @@ def swiglu(x, y=None):
         x, y = paddle.chunk(x, chunks=2, axis=-1)
     return F.silu(x) * y
 
+from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore
+
 
 USE_DS_GEMM = os.getenv("USE_DS_GEMM", "False").lower() == "true"
 
@@ -239,16 +242,31 @@ def compute_expert_w_grad(
         if hasattr(weight, "main_grad"):
             if weight.main_grad is None:
                 weight.main_grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32)
-            result = FP8LinearFunctionBase.kitchen_gemm(
-                input_t,
-                input_t_scale,
-                dout_t,
-                dout_t_scale,
-                is_a_1d_scaled,
-                is_b_1d_scaled,
-                weight.main_grad,
-                rtn_dtype,
-            )
+            if WeightGradStore.enabled:
+                WeightGradStore.put(
+                    partial(
+                        FP8LinearFunctionBase.kitchen_gemm,
+                        input_t,
+                        input_t_scale,
+                        dout_t,
+                        dout_t_scale,
+                        is_a_1d_scaled,
+                        is_b_1d_scaled,
+                        weight.main_grad,
+                        rtn_dtype,
+                    )
+                )
+                result = None
+            else:
+                result = FP8LinearFunctionBase.kitchen_gemm(
+                    input_t,
+                    input_t_scale,
+                    dout_t,
+                    dout_t_scale,
+                    is_a_1d_scaled,
+                    is_b_1d_scaled,
+                    weight.main_grad,
+                    rtn_dtype,
+                )
         else:
             if weight.grad is None:
                 weight.grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32)
@@ -288,9 +306,17 @@ def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_ba
             o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
         )
         if apply_backward_hook:
-            FP8LinearFunctionBase.compute_expert_w_grad(
-                o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, w2, rtn_dtype=paddle.float32
-            )
+            if WeightGradStore.enabled:
+                WeightGradStore.put(
+                    partial(
+                        FP8LinearFunctionBase.compute_expert_w_grad,
+                        o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, w2, rtn_dtype=paddle.float32,
+                    )
+                )
+            else:
+                FP8LinearFunctionBase.compute_expert_w_grad(
+                    o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, w2, rtn_dtype=paddle.float32
+                )
         else:
             dw2 = FP8LinearFunctionBase.kitchen_gemm(
                 o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, rtn_dtype=paddle.float32
@@ -306,9 +332,17 @@ def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_ba
 
         # ===== dw1 = deep_gemm(x_t_fp8, do1_t_fp8) =====
        if apply_backward_hook:
-            FP8LinearFunctionBase.compute_expert_w_grad(
-                x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, w1, rtn_dtype=paddle.float32
-            )
+            if WeightGradStore.enabled:
+                WeightGradStore.put(
+                    partial(
+                        FP8LinearFunctionBase.compute_expert_w_grad,
+                        x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, w1, rtn_dtype=paddle.float32,
+                    )
+                )
+            else:
+                FP8LinearFunctionBase.compute_expert_w_grad(
+                    x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, w1, rtn_dtype=paddle.float32
+                )
         else:
             dw1 = FP8LinearFunctionBase.kitchen_gemm(
                 x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, rtn_dtype=paddle.float32
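In all three fp8_utils hunks the pattern is the same: when WeightGradStore is enabled, the weight-gradient GEMM is not executed; its operands are bound with functools.partial, the resulting callable is queued, and the accumulation into the weight's fp32 main_grad buffer happens only when the queue is drained. A toy version of that deferral with plain dense matmuls; the list, names, and shapes below are illustrative stand-ins, not the FP8 kernels from this commit:

import paddle
from functools import partial

x = paddle.randn([16, 64])        # layer input saved for backward
dout = paddle.randn([16, 128])    # upstream gradient
main_grad = paddle.zeros([64, 128], dtype="float32")   # fp32 accumulator (like weight.main_grad)


def w_grad(inp, grad_out, acc):
    # Deferred dW: accumulate x^T @ dout into the fp32 buffer when invoked.
    with paddle.no_grad():
        acc.add_(paddle.matmul(inp, grad_out, transpose_x=True).cast("float32"))


deferred = [partial(w_grad, x, dout, main_grad)]   # queued instead of executed

# ... dX computation and MoE dispatch/combine communication would run here ...

for fn in deferred:   # drained later; the accumulated result equals running it eagerly
    fn()
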
