@@ -106,6 +106,11 @@ def swiglu(x, y=None):
106
106
x , y = paddle .chunk (x , chunks = 2 , axis = - 1 )
107
107
return F .silu (x ) * y
108
108
109
+ try :
110
+ from paddle .incubate .nn .functional import fused_partial_rope
111
+ except ImportError :
112
+ fused_partial_rope = None
113
+
109
114
110
115
__all__ = [
111
116
"DeepseekV2LMHead" ,
@@ -1089,7 +1094,7 @@ def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
1089
1094
1090
1095
1091
1096
@to_static (backend = "CINN" )
1092
- def qkv_pre_process (
1097
+ def qkv_pre_process_no_fuse (
1093
1098
q , kv , k_pe , rotary_emb , num_heads , q_head_dim , qk_nope_head_dim , v_head_dim , qk_rope_head_dim , position_ids
1094
1099
):
1095
1100
bsz , q_len , _ = q .shape
@@ -1125,6 +1130,50 @@ def qkv_pre_process(
1125
1130
return query_states , key_states , value_states
1126
1131
1127
1132
1133
+ @to_static (backend = "CINN" )
1134
+ def rearrange_kv (kv , k_pe , qk_nope_head_dim , num_heads ):
1135
+ k_nope = kv [..., :qk_nope_head_dim ]
1136
+ value_states = kv [..., qk_nope_head_dim :]
1137
+
1138
+ k_pe = k_pe .expand ([k_pe .shape [0 ], k_pe .shape [1 ], num_heads , k_pe .shape [3 ]])
1139
+ key_states = paddle .concat ([k_nope , k_pe ], axis = - 1 )
1140
+
1141
+ return key_states , value_states
1142
+
1143
+
1144
+ def qkv_pre_process (
1145
+ q , kv , k_pe , rotary_emb , num_heads , q_head_dim , qk_nope_head_dim , v_head_dim , qk_rope_head_dim , position_ids
1146
+ ):
1147
+ if (fused_partial_rope is None ) or (position_ids is not None ):
1148
+ return qkv_pre_process_no_fuse (
1149
+ q , kv , k_pe , rotary_emb , num_heads , q_head_dim , qk_nope_head_dim , v_head_dim , qk_rope_head_dim , position_ids
1150
+ )
1151
+
1152
+ bsz , q_len , _ = q .shape
1153
+
1154
+ target_query_shape = [0 , 0 , num_heads , q_head_dim ]
1155
+ target_key_value_shape = [0 , 0 , num_heads , qk_nope_head_dim + v_head_dim ]
1156
+
1157
+ q = q .reshape (shape = target_query_shape )
1158
+ kv = kv .reshape (shape = target_key_value_shape )
1159
+ k_pe = k_pe .reshape ([- 1 , q_len , 1 , qk_rope_head_dim ])
1160
+
1161
+ value_states = kv [..., qk_nope_head_dim :]
1162
+
1163
+ kv_seq_len = value_states .shape [1 ]
1164
+
1165
+ cos , sin = rotary_emb (value_states , seq_len = kv_seq_len )
1166
+ cos = cos [None , :, None , :]
1167
+ sin = sin [None , :, None , :]
1168
+
1169
+ query_states = fused_partial_rope (q , cos , sin )
1170
+ k_pe = fused_partial_rope (k_pe , cos , sin )
1171
+
1172
+ key_states , value_states = rearrange_kv (kv , k_pe , qk_nope_head_dim , num_heads )
1173
+
1174
+ return query_states , key_states , value_states
1175
+
1176
+
1128
1177
def manul_fwd (
1129
1178
q_init ,
1130
1179
kv_init ,
0 commit comments