
Commit f799631

[inference]Add alibi to flash attn function (#5678)
* add alibi to flash attn function
* rm redundant modifications
1 parent ef8e4ff commit f799631

File tree

2 files changed (+6, -13 lines)


colossalai/inference/core/engine.py

Lines changed: 1 addition & 3 deletions
@@ -121,9 +121,7 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
             casuallm = _supported_models[arch](hf_config)
             if isinstance(casuallm, AutoModelForCausalLM):
                 # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
-                model = (
-                    AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
-                )
+                model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()
             else:
                 model = _supported_models[arch](hf_config)
         else:
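
For context, a minimal sketch (not the repository's code) of the loading call after this change. The checkpoint name below is a placeholder, and moving the resulting FP16 weights onto the GPU is assumed to happen later during engine initialization rather than at load time, which this hunk alone does not show:

from transformers import AutoModelForCausalLM

# Placeholder checkpoint name; any causal-LM checkpoint with custom modeling code follows the same path.
model_or_path = "baichuan-inc/Baichuan2-13B-Base"
# Load in FP16 on the host; per the NOTE in the hunk above, full precision would overflow memory
# for Baichuan-13B. Device placement is assumed to be handled elsewhere and is intentionally omitted.
model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()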

colossalai/inference/modeling/models/nopadding_baichuan.py

Lines changed: 5 additions & 10 deletions
@@ -79,7 +79,6 @@ def baichuan_rmsnorm_forward(
         TypeError(
             "Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'."
         )
-
     if use_cuda_kernel:
         if residual is not None:
             inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
@@ -137,6 +136,7 @@ def __init__(
             self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
                 slopes_start : slopes_start + num_heads
             ].contiguous()
+            self.alibi_slopes = nn.Parameter(self.alibi_slopes)

     @staticmethod
     def from_native_module(
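
get_alibi_slopes itself is not part of this diff. The sketch below illustrates the usual ALiBi slope schedule (a geometric sequence with ratio 2^(-8/n) over the nearest power-of-two head count n, as in the ALiBi paper and BLOOM), which is what such a helper typically computes; it is an illustration under that assumption, not the repository's implementation:

import math
import torch

def alibi_slopes_sketch(num_heads: int, device=None) -> torch.Tensor:
    # For the closest power of two n <= num_heads, slopes are base^1 .. base^n with base = 2^(-8/n).
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-8.0 / closest_power_of_2)
    slopes = torch.pow(
        torch.tensor(base, dtype=torch.float32, device=device),
        torch.arange(1, closest_power_of_2 + 1, dtype=torch.float32, device=device),
    )
    # Heads beyond the power of two reuse the odd powers of the schedule for 2n heads.
    if closest_power_of_2 != num_heads:
        extra_base = 2 ** (-4.0 / closest_power_of_2)  # i.e. 2^(-8 / (2 * closest_power_of_2))
        num_extra = num_heads - closest_power_of_2
        extra = torch.pow(
            torch.tensor(extra_base, dtype=torch.float32, device=device),
            torch.arange(1, 2 * num_extra + 1, 2, dtype=torch.float32, device=device),
        )
        slopes = torch.cat([slopes, extra], dim=0)
    return slopes

For Baichuan-13B's 40 attention heads this yields a mix of base and interpolated slopes. The hunk above then slices slopes_start : slopes_start + num_heads, presumably keeping only the heads owned by this shard, and wraps the result in nn.Parameter so it is registered on the module.
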
@@ -268,19 +268,13 @@ def forward(
         block_size = k_cache.size(-2)

         if is_prompts:
-            if (
-                not is_verifier
-                and use_cuda_kernel
-                and query_states.dtype != torch.float32
-                and use_flash_attn2
-                and not self.use_alibi_attn
-            ):
+            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                 # flash attn 2 currently only supports FP16/BF16.
-                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+                if not self.use_alibi_attn:
+                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                 inference_ops.context_kv_cache_memcpy(
                     key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                 )
-
                 attn_output = flash_attn_varlen_func(
                     query_states,
                     key_states,
@@ -292,6 +286,7 @@ def forward(
                     dropout_p=0.0,
                     softmax_scale=sm_scale,
                     causal=True,
+                    alibi_slopes=self.alibi_slopes,
                 )
                 attn_output = attn_output.view(token_nums, -1)
             else:
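
Taken together, the forward changes route the ALiBi path through FlashAttention 2 as well: the "and not self.use_alibi_attn" guard is dropped from the outer condition, rotary embedding is applied only when ALiBi is not in use, and the per-head slopes are handed to the kernel through flash_attn_varlen_func's alibi_slopes argument. Below is a minimal, self-contained sketch of that call pattern, with illustrative shapes and a flash-attn 2 build that exposes alibi_slopes assumed:

import torch
from flash_attn import flash_attn_varlen_func  # assumes a flash-attn 2 build with ALiBi support

num_heads, head_dim = 32, 128                  # illustrative sizes (power-of-two head count)
total_tokens = 6 + 10                          # two packed prompts of length 6 and 10
cu_seqlens = torch.tensor([0, 6, 16], dtype=torch.int32, device="cuda")
max_seqlen = 10

q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Power-of-two ALiBi schedule: slope_i = 2^(-8 * i / num_heads); fp32, one slope per head.
alibi_slopes = torch.pow(
    2.0, -8.0 * torch.arange(1, num_heads + 1, dtype=torch.float32, device="cuda") / num_heads
)

attn_output = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    dropout_p=0.0,
    softmax_scale=head_dim ** -0.5,
    causal=True,
    alibi_slopes=alibi_slopes,                 # per-head ALiBi biases applied inside the kernel
)
# attn_output: (total_tokens, num_heads, head_dim)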
