Skip to content

Commit 114bee1

Browse files
authored
Add use of fused_partial_rope op (#10942)
1 parent 6964fc8 commit 114bee1

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ def swiglu(x, y=None):
106106
x, y = paddle.chunk(x, chunks=2, axis=-1)
107107
return F.silu(x) * y
108108

109+
try:
110+
from paddle.incubate.nn.functional import fused_partial_rope
111+
except ImportError:
112+
fused_partial_rope = None
113+
109114

110115
__all__ = [
111116
"DeepseekV2LMHead",
@@ -1089,7 +1094,7 @@ def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
10891094

10901095

10911096
@to_static(backend="CINN")
1092-
def qkv_pre_process(
1097+
def qkv_pre_process_no_fuse(
10931098
q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids
10941099
):
10951100
bsz, q_len, _ = q.shape
@@ -1125,6 +1130,50 @@ def qkv_pre_process(
11251130
return query_states, key_states, value_states
11261131

11271132

1133+
@to_static(backend="CINN")
def rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads):
    """Assemble the key and value tensors from the joint ``kv`` tensor.

    The last axis of ``kv`` is cut at ``qk_nope_head_dim``: everything before
    the cut is the key part that has no rotary component, everything after it
    is the value. ``k_pe`` is expanded along axis 2 to ``num_heads`` and then
    appended to the key part along the last axis.

    Args:
        kv: Tensor whose last axis stores the non-rotary key features followed
            by the value features (presumably already split per head by the
            caller -- TODO confirm).
        k_pe: 4-D tensor holding the rotary key features. Axes 0, 1 and 3 keep
            their sizes; axis 2 is expanded to ``num_heads`` (presumably it has
            size 1 on input -- TODO confirm).
        qk_nope_head_dim: Split point on the last axis of ``kv``.
        num_heads: Target size of axis 2 for the expanded ``k_pe``.

    Returns:
        Tuple ``(key_states, value_states)``.
    """
    # Broadcast the rotary key features so that they line up with `kv` on axis 2.
    pe_shape = k_pe.shape
    k_pe_per_head = k_pe.expand([pe_shape[0], pe_shape[1], num_heads, pe_shape[3]])

    # Key = [non-rotary part | rotary part]; value = tail of `kv`.
    key_states = paddle.concat([kv[..., :qk_nope_head_dim], k_pe_per_head], axis=-1)
    value_states = kv[..., qk_nope_head_dim:]

    return key_states, value_states
1142+
1143+
1144+
def qkv_pre_process(
    q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids
):
    """Build query/key/value states, using the fused partial-RoPE op when possible.

    The fused path is taken only when ``fused_partial_rope`` was importable and
    no explicit ``position_ids`` are supplied; in every other case the call is
    forwarded unchanged to ``qkv_pre_process_no_fuse``.

    Args:
        q: 3-D query tensor; its second axis is used as the sequence length.
        kv: Joint key/value tensor, reshaped here to
            ``[.., .., num_heads, qk_nope_head_dim + v_head_dim]``.
        k_pe: Rotary key features, reshaped here to
            ``[-1, q_len, 1, qk_rope_head_dim]``.
        rotary_emb: Callable invoked as ``rotary_emb(x, seq_len=...)`` that
            returns a ``(cos, sin)`` pair.
        num_heads: Number of attention heads.
        q_head_dim: Per-head size of the query.
        qk_nope_head_dim: Per-head size of the non-rotary key part.
        v_head_dim: Per-head size of the value.
        qk_rope_head_dim: Size of the rotary key part.
        position_ids: Optional explicit positions; any non-``None`` value
            selects the non-fused implementation.

    Returns:
        Tuple ``(query_states, key_states, value_states)``.
    """
    can_use_fused_op = fused_partial_rope is not None and position_ids is None
    if not can_use_fused_op:
        return qkv_pre_process_no_fuse(
            q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids
        )

    _, q_len, _ = q.shape

    # A 0 entry in a Paddle reshape target keeps the corresponding input dimension.
    q = q.reshape(shape=[0, 0, num_heads, q_head_dim])
    kv = kv.reshape(shape=[0, 0, num_heads, qk_nope_head_dim + v_head_dim])
    k_pe = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim])

    # The value slice is handed to `rotary_emb` together with its sequence length.
    value_part = kv[..., qk_nope_head_dim:]
    cos_table, sin_table = rotary_emb(value_part, seq_len=value_part.shape[1])

    # Insert leading and head axes of size 1 so the tables become 4-D.
    cos_table = cos_table[None, :, None, :]
    sin_table = sin_table[None, :, None, :]

    # NOTE(review): presumably the op rotates only the trailing rotary channels
    # of the last axis -- confirm against the Paddle op documentation.
    query_states = fused_partial_rope(q, cos_table, sin_table)
    rotated_k_pe = fused_partial_rope(k_pe, cos_table, sin_table)

    key_states, value_states = rearrange_kv(kv, rotated_k_pe, qk_nope_head_dim, num_heads)

    return query_states, key_states, value_states
1175+
1176+
11281177
def manul_fwd(
11291178
q_init,
11301179
kv_init,

0 commit comments

Comments
 (0)