
Commit bb30b41

noooop authored and maxdebayser committed
[Bugfix] Fix ModernBert load & Enable sliding window attention for bidirectional attention. (vllm-project#22637)
Signed-off-by: wang.yuqi <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
Co-authored-by: Max de Bayser <[email protected]>
Signed-off-by: Boyuan Feng <[email protected]>
1 parent 613d5bc commit bb30b41

4 files changed: +101 −59 lines changed

tests/models/language/pooling/test_gte.py

Lines changed: 18 additions & 3 deletions
@@ -4,10 +4,11 @@

 import pytest

-from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
-                      LASTPoolingEmbedModelInfo, check_transformers_version)
+from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
+                      EmbedModelInfo, LASTPoolingEmbedModelInfo,
+                      RerankModelInfo, check_transformers_version)
 from .embed_utils import correctness_test_embed_models
-from .mteb_utils import mteb_test_embed_models
+from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models

 MODELS = [
     ########## BertModel
@@ -58,6 +59,14 @@
         enable_test=False),
 ]

+RERANK_MODELS = [
+    # classifier_pooling: mean
+    CLSPoolingRerankModelInfo(
+        "Alibaba-NLP/gte-reranker-modernbert-base",
+        architecture="ModernBertForSequenceClassification",
+        enable_test=True),
+]
+

 @pytest.mark.parametrize("model_info", MODELS)
 def test_embed_models_mteb(hf_runner, vllm_runner,
@@ -88,3 +97,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,

     correctness_test_embed_models(hf_runner, vllm_runner, model_info,
                                   example_prompts, vllm_extra_kwargs)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_mteb(hf_runner, vllm_runner,
+                            model_info: RerankModelInfo) -> None:
+    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
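
The new RERANK_MODELS entry routes Alibaba-NLP/gte-reranker-modernbert-base through the shared MTEB rerank harness in .mteb_utils. As a rough sketch of what that exercises end to end, here is a direct scoring call written against vLLM's offline LLM.score API; the prompt strings and the task="score" argument are illustrative assumptions, not part of this commit:

    from vllm import LLM

    # Hypothetical direct invocation; the real test delegates to
    # mteb_test_rerank_models and checks MTEB rerank scores instead.
    llm = LLM(model="Alibaba-NLP/gte-reranker-modernbert-base", task="score")

    # Score one query against candidate passages.
    outputs = llm.score("what is a panda?",
                        ["The giant panda is a bear species native to China.",
                         "Paris is the capital of France."])
    for out in outputs:
        print(out.outputs.score)  # higher score = more relevant passage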

vllm/model_executor/models/modernbert.py

Lines changed: 15 additions & 16 deletions
@@ -26,8 +26,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask

-from .interfaces import (SupportsCrossEncoding, SupportsV0Only,
-                         default_pooling_type)
+from .interfaces import SupportsCrossEncoding, default_pooling_type
 from .utils import WeightsMapper, maybe_prefix


@@ -93,16 +92,14 @@ def __init__(self,
             bias=config.attention_bias,
         )

+        sliding_window = None
         if layer_id % config.global_attn_every_n_layers != 0:
-            self.local_attention = (config.local_attention // 2,
-                                    config.local_attention // 2)
+            sliding_window = config.local_attention // 2
+            rope_theta = config.local_rope_theta if config.local_rope_theta \
+                is not None else config.global_rope_theta
         else:
-            self.local_attention = (-1, -1)
+            rope_theta = config.global_rope_theta

-        rope_theta = config.global_rope_theta
-        if self.local_attention != (
-                -1, -1) and config.local_rope_theta is not None:
-            rope_theta = config.local_rope_theta
         self.rotary_emb = ModernBertRotaryEmbedding(config=config,
                                                     head_size=self.head_dim,
                                                     dim=self.head_dim,
@@ -111,7 +108,8 @@ def __init__(self,
                               self.head_dim,
                               self.scaling,
                               prefix=f"{layer_id}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY)
+                              attn_type=AttentionType.ENCODER_ONLY,
+                              per_layer_sliding_window=sliding_window)
         self.Wo = RowParallelLinear(config.hidden_size,
                                     config.hidden_size,
                                     bias=config.attention_bias)
@@ -278,6 +276,7 @@ def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return self.pooling.get_pooling_updates(task)

     def _head(self, pooled_output: torch.Tensor):
+        pooled_output = pooled_output.to(self.dense.weight.dtype)
         return self.norm(self.act(self.dense(pooled_output)))

     def forward(
@@ -296,8 +295,7 @@ def forward(


 @default_pooling_type("CLS")
-class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
-                                          SupportsCrossEncoding):
+class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):

     is_pooling_model = True

@@ -308,6 +306,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooling = ModernBertPooler(config)

         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -317,14 +316,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 Pooler.for_encode(pooler_config),
                 "classify":
                 ClassifierPooler(
-                    pooling=ModernBertPooler(config),
+                    pooling=self.pooling,
                     classifier=self.classifier,
                     act_fn=ClassifierPooler.act_fn_for_seq_cls(
                         vllm_config.model_config),
                 ),
                 "score":
                 ClassifierPooler(
-                    pooling=ModernBertPooler(config),
+                    pooling=self.pooling,
                     classifier=self.classifier,
                     act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                         vllm_config.model_config),
@@ -353,7 +352,7 @@ def weight_filter():
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
                 if name.startswith("head"):
-                    param = params_dict["_pooler.pooler." + name[len("head") + 1:]]
+                    param = params_dict["pooling." + name[len("head") + 1:]]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
@@ -368,5 +367,5 @@ def forward(
         return self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
-            position_ids=positions,
+            positions=positions,
         )
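
The attention change above replaces the unused local_attention tuple with a per_layer_sliding_window argument passed straight to Attention: any layer whose index is not a multiple of global_attn_every_n_layers gets a half-width local window (and local_rope_theta when configured), while the remaining layers keep full bidirectional attention. A minimal sketch of the resulting schedule, assuming ModernBERT-base-like values (local_attention=128, a global layer every 3rd layer; the layer count here is illustrative):

    def layer_windows(num_layers: int = 22,
                      local_attention: int = 128,
                      global_attn_every_n_layers: int = 3):
        # Mirrors the branch in ModernBertAttention.__init__ above.
        windows = []
        for layer_id in range(num_layers):
            if layer_id % global_attn_every_n_layers != 0:
                windows.append(local_attention // 2)  # local: +/-64 tokens
            else:
                windows.append(None)  # global: full bidirectional attention
        return windows

    print(layer_windows()[:6])  # [None, 64, 64, None, 64, 64]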

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 0 deletions
@@ -384,6 +384,8 @@ def __init__(
         self.alibi_slopes = alibi_slopes
         if sliding_window is None:
             self.sliding_window = (-1, -1)
+        elif attn_type == AttentionType.ENCODER_ONLY:
+            self.sliding_window = (sliding_window - 1, sliding_window - 1)
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
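
For encoder-only layers the window tuple becomes symmetric. Under FlashAttention's (left, right) convention, query position i may attend to key position j when i - j <= left and j - i <= right, so (sliding_window - 1, 0) is the usual causal window while (sliding_window - 1, sliding_window - 1) lets each token look the same distance both backward and forward. A small mask sketch of that convention (an illustration of the semantics, not the kernel):

    import torch

    def window_mask(seq_len: int, left: int, right: int) -> torch.Tensor:
        # True where query i may attend to key j:
        # i - j <= left and j - i <= right.
        i = torch.arange(seq_len).unsqueeze(1)
        j = torch.arange(seq_len).unsqueeze(0)
        return (i - j <= left) & (j - i <= right)

    w = 4  # example sliding_window value
    causal = window_mask(8, w - 1, 0)             # decoder: look back only
    bidirectional = window_mask(8, w - 1, w - 1)  # encoder-only: symmetric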

vllm/v1/worker/gpu_model_runner.py

Lines changed: 66 additions & 40 deletions
@@ -826,7 +826,8 @@ def _prepare_inputs(
         # Prepare encoder attention metadata separately
         # (encoder layers are not in KV cache groups)
         if self.is_encoder_only_model:
-            common_attn_metadata, encoder_attn_metadata = \
+
+            per_layer_metadata = \
                 self._build_encoder_only_attn_metadata(
                     scheduler_output)

@@ -835,6 +836,8 @@ def _prepare_inputs(
                 self.vllm_config, Attention)
             for layer_name, attn_module in attention_layers.items():
                 if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                    common_attn_metadata, encoder_attn_metadata =\
+                        per_layer_metadata[layer_name]
                     attn_metadata[layer_name] = encoder_attn_metadata

         # Prepare the attention metadata for each KV cache group and make layers
@@ -2686,30 +2689,41 @@ def create_attn_groups(
         # Check if model is encoder-only
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
-        attn_specs = list[AttentionSpec]()
-        for attn_module in attn_layers.values():
+        attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
+        for layer_name, attn_module in attn_layers.items():

             if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                assert attn_module.sliding_window is None, "Sliding "
-                "window attention is not supported for encoder-only models"
-
-                attn_specs.append(
-                    FullAttentionSpec(block_size=block_size,
-                                      num_kv_heads=attn_module.num_kv_heads,
-                                      head_size=attn_module.head_size,
-                                      dtype=self.kv_cache_dtype,
-                                      use_mla=use_mla))
+                if attn_module.sliding_window is None:
+                    attn_spec: AttentionSpec = FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        use_mla=use_mla)
+                else:
+                    attn_spec = SlidingWindowSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        sliding_window=attn_module.sliding_window,
+                        use_mla=use_mla)
+                attn_specs[attn_spec].append(layer_name)
+
             else:
                 raise ValueError("Expected only encoder-only layers")

         if len(attn_specs) > 0:
-            assert len(attn_specs) == len(attn_layers), \
-                "All or none of the layers are expected to be encoder-only"
+            total_layers = 0
+            for attn_spec, layer_names in attn_specs.items():

-            attn_backends = get_attn_backends_for_layers(attn_layers.keys())
+                attn_backends = get_attn_backends_for_layers(layer_names)
+                total_layers += len(layer_names)

-            self.attn_groups.append(
-                create_attn_groups(attn_backends, attn_specs[0]))
+                self.attn_groups.append(
+                    create_attn_groups(attn_backends, attn_spec))
+            assert total_layers == len(attn_layers), \
+                "All or none of the layers are expected to be encoder-only"
         self.is_encoder_only_model = True

     def calculate_reorder_batch_threshold(self) -> None:
@@ -3074,7 +3088,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

     def _build_encoder_only_attn_metadata(
             self, scheduler_output: "SchedulerOutput") -> \
-                tuple[CommonAttentionMetadata, Any]:
+                dict[str, tuple[CommonAttentionMetadata, Any]]:
         """Prepare encoder attention metadata for encoder-only models.

         Args:
@@ -3091,33 +3105,45 @@ def _build_encoder_only_attn_metadata(
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         max_num_scheduled_tokens = max(tokens)

-        # Use the first attention metadata builder
-        # to create encoder attention metadata
-        builder = self.attn_groups[0][0].metadata_builder
-
         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
                                         device=self.device)
         dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
                                          dtype=torch.int32,
                                          device=self.device)

-        common_metadata = CommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:num_reqs + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens=self.seq_lens[:num_reqs],
-            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-            num_computed_tokens_cpu=self.input_batch.
-            num_computed_tokens_cpu_tensor[:num_reqs],
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            block_table_tensor=dummy_block_table,
-            slot_mapping=dummy_slot_mapping,
-            causal=False,
-        )
+        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()

-        return common_metadata, builder.build(
-            common_prefix_len=0,  # No cascade for encoder
-            common_attn_metadata=common_metadata,
-        )
+        for attn_group_list in self.attn_groups:
+
+            assert len(attn_group_list) == 1
+            attn_group = attn_group_list[0]
+
+            # Use the first attention metadata builder
+            # to create encoder attention metadata
+            builder = attn_group.metadata_builder
+
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+                seq_lens=self.seq_lens[:num_reqs],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                block_table_tensor=dummy_block_table,
+                slot_mapping=dummy_slot_mapping,
+                causal=False,
+            )
+
+            metadata = builder.build(
+                common_prefix_len=0,  # No cascade for encoder
+                common_attn_metadata=common_metadata,
+            )
+
+            for layer_name in attn_group.layer_names:
+                group_metadata[layer_name] = (common_metadata, metadata)
+
+        return group_metadata
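
Because encoder-only models may now mix full-attention and sliding-window layers, a single spec no longer covers every layer: the runner buckets layer names by their AttentionSpec and builds one attention group, and later one metadata set, per bucket. A toy version of that grouping pattern, with a hypothetical SimpleSpec standing in for vLLM's spec classes (the dict-keyed grouping requires the specs to be hashable):

    from collections import defaultdict
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class SimpleSpec:
        # Stand-in for AttentionSpec; frozen dataclasses are hashable, so
        # equal specs collapse into a single dict key (one group per spec).
        sliding_window: int | None

    layers = {"layers.0": None, "layers.1": 64,
              "layers.2": 64, "layers.3": None}

    specs: dict[SimpleSpec, list[str]] = defaultdict(list)
    for name, window in layers.items():
        specs[SimpleSpec(window)].append(name)

    for spec, names in specs.items():
        print(spec, names)
    # SimpleSpec(sliding_window=None) ['layers.0', 'layers.3']
    # SimpleSpec(sliding_window=64) ['layers.1', 'layers.2']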
