@@ -79,7 +79,6 @@ def baichuan_rmsnorm_forward(
         TypeError(
             "Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'."
         )
-
     if use_cuda_kernel:
         if residual is not None:
             inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
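For context, `fused_add_rms_layernorm` fuses the residual addition with RMSNorm in a single CUDA kernel. A minimal eager-mode sketch of the same math (a hypothetical reference helper, not ColossalAI's actual non-CUDA fallback; the real kernel also updates its buffers in place):

```python
import torch

def add_rmsnorm_ref(hidden_states: torch.Tensor, residual: torch.Tensor,
                    weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Add the residual first, then normalize by the root mean square
    # and scale by the learned weight (the usual RMSNorm formula).
    hidden_states = hidden_states + residual
    variance = hidden_states.float().pow(2).mean(dim=-1, keepdim=True)
    normed = hidden_states.float() * torch.rsqrt(variance + eps)
    return (weight * normed).to(hidden_states.dtype)
```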
@@ -137,6 +136,7 @@ def __init__(
         self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
             slopes_start : slopes_start + num_heads
         ].contiguous()
+        self.alibi_slopes = nn.Parameter(self.alibi_slopes)

     @staticmethod
     def from_native_module(
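The `alibi_slopes` wrapped in `nn.Parameter` above come from `get_alibi_slopes`, sliced down to this rank's heads. The slopes follow the standard ALiBi recipe, a geometric sequence per head; a sketch under that assumption (not necessarily ColossalAI's exact `get_alibi_slopes`):

```python
import math
import torch

def alibi_slopes_sketch(num_heads: int, device: torch.device) -> torch.Tensor:
    # Slopes 2^(-8k/n) for k = 1..n when n is a power of two; for
    # non-power-of-two head counts the remaining heads are filled from
    # the half-step sequence, as in the original ALiBi paper.
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-8.0 / closest_pow2)
    slopes = torch.pow(base, torch.arange(1, closest_pow2 + 1, device=device))
    if closest_pow2 != num_heads:
        extra_base = 2.0 ** (-4.0 / closest_pow2)
        num_extra = num_heads - closest_pow2
        extra = torch.pow(extra_base, torch.arange(1, 2 * num_extra + 1, 2, device=device))
        slopes = torch.cat([slopes, extra])
    return slopes
```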
@@ -268,19 +268,13 @@ def forward(
         block_size = k_cache.size(-2)

         if is_prompts:
-            if (
-                not is_verifier
-                and use_cuda_kernel
-                and query_states.dtype != torch.float32
-                and use_flash_attn2
-                and not self.use_alibi_attn
-            ):
+            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                 # flash attn 2 currently only supports FP16/BF16.
-                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+                if not self.use_alibi_attn:
+                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                 inference_ops.context_kv_cache_memcpy(
                     key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                 )
-
                 attn_output = flash_attn_varlen_func(
                     query_states,
                     key_states,
@@ -292,6 +286,7 @@ def forward(
                     dropout_p=0.0,
                     softmax_scale=sm_scale,
                     causal=True,
+                    alibi_slopes=self.alibi_slopes,
                 )
                 attn_output = attn_output.view(token_nums, -1)
             else:
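Passing `alibi_slopes` lets flash-attn apply the per-head linear position bias itself, so the ALiBi path can stay inside the fused kernel; the rotary embedding is guarded by `if not self.use_alibi_attn` because the ALiBi variant (Baichuan-13B) does not use RoPE. Conceptually, for causal attention the bias added to the scores is `-slope * (i - j)`. A minimal eager-mode reference, assuming flash-attn's documented `alibi_slopes` semantics (not the kernel itself):

```python
import torch

def causal_alibi_attention_ref(q, k, v, slopes, sm_scale):
    # q, k, v: (num_heads, seq_len, head_dim); slopes: (num_heads,)
    seq_len = q.size(1)
    scores = torch.einsum("hqd,hkd->hqk", q, k) * sm_scale
    pos = torch.arange(seq_len, device=q.device)
    # ALiBi: penalize each key by its distance to the query, scaled by the per-head slope.
    dist = (pos.view(1, -1, 1) - pos.view(1, 1, -1)).clamp(min=0)
    scores = scores - slopes.view(-1, 1, 1) * dist
    # Causal mask: keys after the query position are excluded.
    scores = scores.masked_fill(pos.view(1, 1, -1) > pos.view(1, -1, 1), float("-inf"))
    return torch.einsum("hqk,hkd->hqd", torch.softmax(scores, dim=-1), v)
```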