
Commit 1af1039

NickLucche authored and BoyuanFeng committed
[CI] Fix tests/v1/e2e/test_kv_sharing_fast_prefill.py import on test (vllm-project#22815)
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Boyuan Feng <[email protected]>
1 parent 8e96175 commit 1af1039

1 file changed: +7 additions, -5 deletions

tests/v1/e2e/test_kv_sharing_fast_prefill.py

Lines changed: 7 additions & 5 deletions
@@ -11,7 +11,8 @@
 from vllm.config import CompilationConfig, CompilationLevel
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
+from vllm.model_executor.models.gemma3n_mm import (
+    Gemma3nForConditionalGeneration)
 from vllm.model_executor.models.registry import ModelRegistry
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.sequence import IntermediateTensors
@@ -32,12 +33,13 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds, **kwargs)
+        hidden_states = super().forward(input_ids, positions,
+                                        intermediate_tensors, inputs_embeds,
+                                        **kwargs)
         attn_metadata = get_forward_context().attn_metadata
         # attn_metadata is None during dummy runs
         if (attn_metadata is not None
-                and self.cache_config.kv_sharing_fast_prefill):
+                and self.language_model.cache_config.kv_sharing_fast_prefill):
             assert isinstance(attn_metadata, dict)  # true in V1
             # Gemma3n-E2B has 30 layers, with last 20 layers being
             # cross-decoder layers. Check attention metadata is correct
@@ -52,7 +54,7 @@ def forward(
 
             # Last layer will be a KV sharing layer
             layer_attn_metadata = attn_metadata[
-                self.model.language_model.layers[-1].self_attn.attn.layer_name]
+                self.language_model.model.layers[-1].self_attn.attn.layer_name]
             logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
             assert logits_indices_padded is not None
             num_logits_indices = layer_attn_metadata.num_logits_indices
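
For readers following the change, here is a minimal sketch of how the patched test wrapper fits together after this commit. It is an illustrative reconstruction, not the full test file: the forward override and attribute paths come from the diff above, while the ModelRegistry.register_model call and the trimmed assertion body are assumptions about the surrounding test setup.

# Sketch only: assumes vLLM provides ModelRegistry.register_model and that the
# multimodal Gemma3n model now lives in gemma3n_mm, as the import change shows.
from typing import Optional, Union

import torch

from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n_mm import (
    Gemma3nForConditionalGeneration)
from vllm.model_executor.models.registry import ModelRegistry
from vllm.sequence import IntermediateTensors


class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Delegate to the real multimodal forward; the old code called
        # self.model(...) directly, which no longer matches the class layout.
        hidden_states = super().forward(input_ids, positions,
                                        intermediate_tensors, inputs_embeds,
                                        **kwargs)
        attn_metadata = get_forward_context().attn_metadata
        # attn_metadata is None during dummy runs; the cache config now hangs
        # off the wrapped language_model rather than the top-level class.
        if (attn_metadata is not None
                and self.language_model.cache_config.kv_sharing_fast_prefill):
            assert isinstance(attn_metadata, dict)  # true in V1
            # The real test also checks per-layer logits_indices_padded here;
            # those assertions are elided in this sketch (see the diff above).
        return hidden_states


# Assumed usage: register the wrapper so the engine instantiates it whenever
# the Gemma3n architecture is requested during the test.
ModelRegistry.register_model("Gemma3nForConditionalGeneration",
                             TestGemma3nForConditionalGeneration)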
