
Commit 0793bf1

[V1][MLA][SW-234434] Enable MLA for V1 - ported from vllm-gaudi (#1628)
https://jira.habana-labs.com/browse/SW-234434

## Essential Elements of an Effective PR Description Checklist

- [x] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
- [ ] The test plan, such as providing a test command.
- [ ] The test results, such as pasting a before/after results comparison or e2e results.

## Purpose

Backport the vllm-gaudi V1 MLA enablement to vllm-fork for a PRC customer request.

## Test on DeepSeek-V2-Lite-Chat

```
HABANA_VISIBLE_DEVICES=all VLLM_SKIP_WARMUP=true \
PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 VLLM_CONTIGUOUS_PA=False \
lm_eval --model vllm \
  --model_args "pretrained=DeepSeek-V2-Lite-Chat/,tensor_parallel_size=1,distributed_executor_backend=mp,trust_remote_code=true,max_model_len=4096,use_v2_block_manager=True,dtype=bfloat16,max_num_seqs=128" \
  --tasks gsm8k --num_fewshot "5" \
  --batch_size "auto" --log_samples --output_path gsm8k_acc_DeepSeek-V2-Lite-Chat.json
```

| Tasks | Version | Filter           | n-shot | Metric      |   | Value  |   | Stderr |
|-------|--------:|------------------|-------:|-------------|---|-------:|---|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑ | 0.6581 | ± | 0.0131 |
|       |         | strict-match     |      5 | exact_match | ↑ | 0.6482 | ± | 0.0132 |

## Test on DeepSeek-R1

```
HABANA_VISIBLE_DEVICES=all \
VLLM_CONTIGUOUS_PA=False \
VLLM_USE_V1=1 \
PT_HPU_LAZY_MODE=1 \
VLLM_SKIP_WARMUP=true \
PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
PT_HPU_WEIGHT_SHARING=0 \
lm_eval --model vllm \
  --model_args "pretrained=DeepSeek-R1,tensor_parallel_size=8,distributed_executor_backend=mp,trust_remote_code=true,max_model_len=16384,use_v2_block_manager=True,dtype=bfloat16,max_num_seqs=128,gpu_memory_utilization=0.9,enable_expert_parallel=True," \
  --tasks gsm8k --num_fewshot "8" \
  --batch_size "128" --limit 256 --log_samples --output_path gsm8k_acc_${MODEL_NAME}.json
```

vllm (pretrained=DeepSeek-R1,tensor_parallel_size=8,distributed_executor_backend=mp,trust_remote_code=true,max_model_len=16384,use_v2_block_manager=True,dtype=bfloat16,max_num_seqs=128,gpu_memory_utilization=0.9,enable_expert_parallel=True,), gen_kwargs: (None), limit: 256.0, num_fewshot: 8, batch_size: 128

| Tasks | Version | Filter           | n-shot | Metric      |   | Value  |   | Stderr |
|-------|--------:|------------------|-------:|-------------|---|-------:|---|-------:|
| gsm8k |       3 | flexible-extract |      8 | exact_match | ↑ | 0.9688 | ± | 0.0109 |
|       |         | strict-match     |      8 | exact_match | ↑ | 0.9609 | ± | 0.0121 |

Signed-off-by: Chendi.Xue <[email protected]>
1 parent bf3e6b0 commit 0793bf1
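
For a quicker smoke test than the full gsm8k runs above, a minimal offline-inference sketch of the same V1 MLA path. Assumptions: the prompt, sampling settings, and local model path are illustrative; the environment variables mirror the DeepSeek-V2-Lite-Chat command in the PR description and must be set before vLLM is imported.

```python
# Minimal V1 MLA smoke test on HPU (illustrative; not part of this commit).
import os

os.environ["HABANA_VISIBLE_DEVICES"] = "all"
os.environ["VLLM_USE_V1"] = "1"             # route to the V1 engine
os.environ["PT_HPU_LAZY_MODE"] = "1"        # HPU lazy execution mode
os.environ["VLLM_CONTIGUOUS_PA"] = "False"  # contiguous PA disabled in the tests above
os.environ["VLLM_SKIP_WARMUP"] = "true"     # skip warmup for a quick run

from vllm import LLM, SamplingParams

llm = LLM(
    model="DeepSeek-V2-Lite-Chat/",  # local checkpoint path, as in the test command
    trust_remote_code=True,
    dtype="bfloat16",
    max_model_len=4096,
    max_num_seqs=128,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```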

File tree

3 files changed: +36 additions, −6 deletions

- vllm/attention/backends/hpu_attn.py
- vllm/platforms/hpu.py
- vllm/v1/worker/hpu_model_runner.py
vllm/attention/backends/hpu_attn.py

Lines changed: 24 additions & 4 deletions
```diff
@@ -250,18 +250,35 @@ def forward(
         if kv_cache is not None and len(kv_cache) == 2:
             self.latent_cache_k(latent_vec_k, kv_cache[0], slot_mapping)
             k_cache = kv_cache[0]
+        else:
+            k_cache = None

         if is_prefill:
-            return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata,
-                                         batch_size)
+            return self._forward_prefill(q, latent_vec_k, k_cache,
+                                         attn_metadata, batch_size)
         else:
             return self._forward_decode(decode_ql_nope, q_pe, k_cache,
                                         attn_metadata, batch_size)

     def _forward_prefill(  # type: ignore
-            self, q: torch.Tensor, k_c_normed: torch.Tensor,
-            k_pe: torch.Tensor, attn_metadata: HPUAttentionMetadata,
+            self, q: torch.Tensor, latent_vec_k: torch.Tensor,
+            k_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata,
             batch_size: int) -> torch.Tensor:
+        ##### get prefix cache #####
+        if attn_metadata.block_list is not None:
+            current = latent_vec_k
+            past = self.latent_cache_k.fetch_from_cache(
+                k_cache.unflatten(0, (-1, attn_metadata.block_size)),
+                attn_metadata.block_list)
+            past = past.view(-1, past.shape[-1])
+            current = torch.concat((past, current), dim=0)
+            latent_vec_k = current
+        # =========================== #
+
+        k_c_normed, k_pe = latent_vec_k.split(
+            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
+
         kv_nope = self.kv_b_proj(k_c_normed)[0]\
             .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
         k_nope, v = kv_nope\
@@ -290,11 +307,14 @@ def _forward_prefill(  # type: ignore
             value=v_padded,
             is_causal=True,
             attn_bias=attn_metadata.attn_bias,
+            position_bias=None,
             valid_seq_lengths=attn_metadata.seq_lens_tensor,
             scale=self.scale,
             matmul_qk_op=self.matmul_qk,
             softmax_op=self.softmax,
             matmul_av_op=self.matmul_av,
+            keys_fetch_func=self.latent_cache_k.fetch_from_cache,
+            values_fetch_func=None,
             fsdpa_op=self.fused_scaled_dot_product_attention.apply \
                 if self.fused_scaled_dot_product_attention is not None else None)
         attn_output = out.view(batch_size, -1, self.num_heads, q.shape[-1])
```
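
The `_forward_prefill` changes do two things: when `attn_metadata.block_list` is set, cached latent blocks are fetched and prepended to the freshly computed latent vectors, and the split into `k_c_normed` / `k_pe` now happens inside the method rather than in the caller. Below is a standalone sketch of the same tensor manipulation; the block and feature sizes are illustrative assumptions in the DeepSeek-V2-Lite range, and plain tensor indexing stands in for `latent_cache_k.fetch_from_cache`, whose exact signature is not shown here.

```python
# Illustrative prefix-cache concat + latent split (not the backend code itself).
import torch

kv_lora_rank, qk_rope_head_dim = 512, 64        # assumed MLA dimensions
latent_dim = kv_lora_rank + qk_rope_head_dim    # 576 features per cached token
block_size, num_blocks, new_tokens = 128, 4, 8

# Flat latent cache as stored on device: (num_blocks * block_size, latent_dim)
k_cache = torch.randn(num_blocks * block_size, latent_dim)
block_list = torch.tensor([2, 0])               # blocks holding the cached prefix

# Fetch the prefix blocks and flatten them back to (prefix_tokens, latent_dim)
past = k_cache.unflatten(0, (-1, block_size))[block_list]
past = past.view(-1, past.shape[-1])

# Prepend the cached prefix to the current tokens' latent vectors
latent_vec_k = torch.concat((past, torch.randn(new_tokens, latent_dim)), dim=0)

# Split into the compressed KV latent and the rotary key part
k_c_normed, k_pe = latent_vec_k.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
k_pe = k_pe.view(-1, 1, qk_rope_head_dim)
print(k_c_normed.shape, k_pe.shape)  # torch.Size([264, 512]) torch.Size([264, 1, 64])
```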

vllm/platforms/hpu.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -40,9 +40,12 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool,
                              use_mla: bool) -> str:
-        if use_v1:
+        if use_v1 and not use_mla:
             logger.info("Using HPUAttentionV1 backend.")
             return "vllm.v1.attention.backends.hpu_attn.HPUAttentionBackendV1"
+        if use_v1 and use_mla:
+            logger.info("Using HPUAttentionMLA backend.")
+            return "vllm.attention.backends.hpu_attn.HPUMLAAttentionBackend"
         if use_mla:
             logger.info("Using HPUAttentionMLA backend.")
             return "vllm.attention.backends.hpu_attn.HPUMLAAttentionBackend"
```

vllm/v1/worker/hpu_model_runner.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -2349,11 +2349,18 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                 kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                     num_blocks + 1, kv_cache_spec.block_size,
                     kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
+                v_cache_shape = None if self.model_config.use_mla \
+                    else kv_cache_shape
                 dtype = kv_cache_spec.dtype
                 key_cache = torch.zeros(kv_cache_shape,
                                         dtype=dtype,
                                         device=self.device)
-                value_cache = torch.zeros_like(key_cache)
+                if v_cache_shape is not None:
+                    value_cache = torch.zeros(v_cache_shape,
+                                              dtype=dtype,
+                                              device=self.device)
+                else:
+                    value_cache = None
                 kv_caches[layer_name] = (key_cache, value_cache)
             else:
                 # TODO: add new branches when introducing more types of
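```

The runner-side change mirrors the backend change: for MLA layers only the key (latent) cache is materialized and `value_cache` stays `None`, so the value half of the cache is never allocated. A self-contained sketch of that behavior; the shapes, dtype, device, and helper name are illustrative assumptions, since the real shapes come from `attn_backend.get_kv_cache_shape()`.

```python
# Sketch of MLA-aware KV-cache allocation (illustration only, not the runner code).
from typing import Optional, Tuple

import torch


def allocate_kv_cache(kv_cache_shape, dtype, device,
                      use_mla: bool) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    key_cache = torch.zeros(kv_cache_shape, dtype=dtype, device=device)
    # With MLA the cached latent vector is enough to reconstruct both keys and
    # values, so no separate value cache is allocated.
    value_cache = None if use_mla else torch.zeros_like(key_cache)
    return key_cache, value_cache


k, v = allocate_kv_cache((16, 128, 1, 576), torch.bfloat16, "cpu", use_mla=True)
print(k.shape, v)  # torch.Size([16, 128, 1, 576]) None
```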
