Commit 51c51a4

noooop authored and maxdebayser committed
[Bugfix] Fix ModernBert load & Enable sliding window attention for bidirectional attention. (vllm-project#22637)
Signed-off-by: wang.yuqi <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
Co-authored-by: Max de Bayser <[email protected]>
Signed-off-by: Paul Pak <[email protected]>
1 parent 0a1bc4b commit 51c51a4

File tree

4 files changed: +101 -59 lines changed

tests/models/language/pooling/test_gte.py
vllm/model_executor/models/modernbert.py
vllm/v1/attention/backends/flash_attn.py
vllm/v1/worker/gpu_model_runner.py

tests/models/language/pooling/test_gte.py

Lines changed: 18 additions & 3 deletions
@@ -4,10 +4,11 @@

 import pytest

-from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
-                      LASTPoolingEmbedModelInfo, check_transformers_version)
+from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
+                      EmbedModelInfo, LASTPoolingEmbedModelInfo,
+                      RerankModelInfo, check_transformers_version)
 from .embed_utils import correctness_test_embed_models
-from .mteb_utils import mteb_test_embed_models
+from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models

 MODELS = [
     ########## BertModel
@@ -58,6 +59,14 @@
         enable_test=False),
 ]

+RERANK_MODELS = [
+    # classifier_pooling: mean
+    CLSPoolingRerankModelInfo(
+        "Alibaba-NLP/gte-reranker-modernbert-base",
+        architecture="ModernBertForSequenceClassification",
+        enable_test=True),
+]
+

 @pytest.mark.parametrize("model_info", MODELS)
 def test_embed_models_mteb(hf_runner, vllm_runner,
@@ -88,3 +97,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,

     correctness_test_embed_models(hf_runner, vllm_runner, model_info,
                                   example_prompts, vllm_extra_kwargs)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_mteb(hf_runner, vllm_runner,
+                            model_info: RerankModelInfo) -> None:
+    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
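
For context, the new RERANK_MODELS entry exercises Alibaba-NLP/gte-reranker-modernbert-base through the MTEB rerank harness. A minimal offline scoring sketch, assuming vLLM's LLM.score() cross-encoder entry point (the exact task/runner argument and output field names may differ between vLLM versions):

    from vllm import LLM

    # Hypothetical usage sketch, not part of this commit; the task argument
    # name varies across vLLM versions.
    llm = LLM(model="Alibaba-NLP/gte-reranker-modernbert-base", task="score")

    query = "What is the capital of France?"
    docs = ["Paris is the capital of France.", "The Eiffel Tower is in Berlin."]

    # score() returns one relevance score per (query, document) pair.
    outputs = llm.score(query, docs)
    for doc, out in zip(docs, outputs):
        print(f"{out.outputs.score:.4f}  {doc}")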

vllm/model_executor/models/modernbert.py

Lines changed: 15 additions & 16 deletions
@@ -26,8 +26,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask

-from .interfaces import (SupportsCrossEncoding, SupportsV0Only,
-                         default_pooling_type)
+from .interfaces import SupportsCrossEncoding, default_pooling_type
 from .utils import WeightsMapper, maybe_prefix


@@ -93,16 +92,14 @@ def __init__(self,
             bias=config.attention_bias,
         )

+        sliding_window = None
         if layer_id % config.global_attn_every_n_layers != 0:
-            self.local_attention = (config.local_attention // 2,
-                                    config.local_attention // 2)
+            sliding_window = config.local_attention // 2
+            rope_theta = config.local_rope_theta if config.local_rope_theta \
+                is not None else config.global_rope_theta
         else:
-            self.local_attention = (-1, -1)
+            rope_theta = config.global_rope_theta

-        rope_theta = config.global_rope_theta
-        if self.local_attention != (
-                -1, -1) and config.local_rope_theta is not None:
-            rope_theta = config.local_rope_theta
         self.rotary_emb = ModernBertRotaryEmbedding(config=config,
                                                     head_size=self.head_dim,
                                                     dim=self.head_dim,
@@ -111,7 +108,8 @@ def __init__(self,
                               self.head_dim,
                               self.scaling,
                               prefix=f"{layer_id}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY)
+                              attn_type=AttentionType.ENCODER_ONLY,
+                              per_layer_sliding_window=sliding_window)
         self.Wo = RowParallelLinear(config.hidden_size,
                                     config.hidden_size,
                                     bias=config.attention_bias)
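
A small sketch of the layer pattern these hunks implement: every global_attn_every_n_layers-th layer keeps full bidirectional attention, and the remaining layers get a half-width sliding window passed through per_layer_sliding_window. The config values below are illustrative (ModernBERT-style), not read from a real checkpoint:

    # Illustrative values only; a real model reads these from its HF config.
    global_attn_every_n_layers = 3
    local_attention = 128  # total local window width

    for layer_id in range(6):
        if layer_id % global_attn_every_n_layers != 0:
            sliding_window = local_attention // 2  # local, bidirectional window
        else:
            sliding_window = None                  # global attention layer
        print(layer_id, sliding_window)
    # 0 None, 1 64, 2 64, 3 None, 4 64, 5 64
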
@@ -278,6 +276,7 @@ def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return self.pooling.get_pooling_updates(task)

     def _head(self, pooled_output: torch.Tensor):
+        pooled_output = pooled_output.to(self.dense.weight.dtype)
         return self.norm(self.act(self.dense(pooled_output)))

     def forward(
@@ -296,8 +295,7 @@ def forward(


 @default_pooling_type("CLS")
-class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
-                                          SupportsCrossEncoding):
+class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):

     is_pooling_model = True

@@ -308,6 +306,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooling = ModernBertPooler(config)

         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -317,14 +316,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             Pooler.for_encode(pooler_config),
             "classify":
             ClassifierPooler(
-                pooling=ModernBertPooler(config),
+                pooling=self.pooling,
                 classifier=self.classifier,
                 act_fn=ClassifierPooler.act_fn_for_seq_cls(
                     vllm_config.model_config),
             ),
             "score":
             ClassifierPooler(
-                pooling=ModernBertPooler(config),
+                pooling=self.pooling,
                 classifier=self.classifier,
                 act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                     vllm_config.model_config),
@@ -353,7 +352,7 @@ def weight_filter():
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             if name.startswith("head"):
-                param = params_dict["_pooler.pooler." + name[len("head") + 1:]]
+                param = params_dict["pooling." + name[len("head") + 1:]]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
@@ -368,5 +367,5 @@ def forward(
         return self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
-            position_ids=positions,
+            positions=positions,
         )
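
The load fix above boils down to a key remap: checkpoint weights under head.* now land on the module registered as self.pooling, instead of the stale _pooler.pooler.* prefix that no longer matches any registered parameter. A tiny sketch of the rename (the helper name is hypothetical, the string logic mirrors load_weights above):

    # Hypothetical helper mirroring the rename performed in load_weights.
    def remap_head_weight(name: str) -> str:
        assert name.startswith("head")
        return "pooling." + name[len("head") + 1:]

    print(remap_head_weight("head.dense.weight"))  # -> pooling.dense.weight
    print(remap_head_weight("head.norm.weight"))   # -> pooling.norm.weight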

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 0 deletions
@@ -384,6 +384,8 @@ def __init__(
         self.alibi_slopes = alibi_slopes
         if sliding_window is None:
             self.sliding_window = (-1, -1)
+        elif attn_type == AttentionType.ENCODER_ONLY:
+            self.sliding_window = (sliding_window - 1, sliding_window - 1)
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
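
The flash-attn backend expresses windows as (left, right) token budgets around each query position, with -1 meaning unbounded on that side. Causal decoders use (w - 1, 0); this change gives encoder-only layers the symmetric (w - 1, w - 1) so the local window is bidirectional. A sketch of the equivalent boolean mask (illustrative, not the kernel code):

    import torch

    def window_mask(seq_len: int, left: int, right: int) -> torch.Tensor:
        # Query i may attend to key j iff i - left <= j <= i + right;
        # a budget of -1 means unbounded on that side.
        i = torch.arange(seq_len).unsqueeze(1)
        j = torch.arange(seq_len).unsqueeze(0)
        mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
        if left >= 0:
            mask &= j >= i - left
        if right >= 0:
            mask &= j <= i + right
        return mask

    w = 4  # per-layer sliding window from the model config
    causal_local = window_mask(8, w - 1, 0)             # decoder-style local window
    bidirectional_local = window_mask(8, w - 1, w - 1)  # encoder-only case added here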

vllm/v1/worker/gpu_model_runner.py

Lines changed: 66 additions & 40 deletions
@@ -827,7 +827,8 @@ def _prepare_inputs(
         # Prepare encoder attention metadata separately
         # (encoder layers are not in KV cache groups)
         if self.is_encoder_only_model:
-            common_attn_metadata, encoder_attn_metadata = \
+
+            per_layer_metadata = \
                 self._build_encoder_only_attn_metadata(
                     scheduler_output)

@@ -836,6 +837,8 @@ def _prepare_inputs(
                 self.vllm_config, Attention)
             for layer_name, attn_module in attention_layers.items():
                 if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                    common_attn_metadata, encoder_attn_metadata =\
+                        per_layer_metadata[layer_name]
                     attn_metadata[layer_name] = encoder_attn_metadata

         # Prepare the attention metadata for each KV cache group and make layers
@@ -2684,30 +2687,41 @@ def create_attn_groups(
         # Check if model is encoder-only
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
-        attn_specs = list[AttentionSpec]()
-        for attn_module in attn_layers.values():
+        attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
+        for layer_name, attn_module in attn_layers.items():

             if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                assert attn_module.sliding_window is None, "Sliding "
-                "window attention is not supported for encoder-only models"
-
-                attn_specs.append(
-                    FullAttentionSpec(block_size=block_size,
-                                      num_kv_heads=attn_module.num_kv_heads,
-                                      head_size=attn_module.head_size,
-                                      dtype=self.kv_cache_dtype,
-                                      use_mla=use_mla))
+                if attn_module.sliding_window is None:
+                    attn_spec: AttentionSpec = FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        use_mla=use_mla)
+                else:
+                    attn_spec = SlidingWindowSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        sliding_window=attn_module.sliding_window,
+                        use_mla=use_mla)
+                attn_specs[attn_spec].append(layer_name)
+
             else:
                 raise ValueError("Expected only encoder-only layers")

         if len(attn_specs) > 0:
-            assert len(attn_specs) == len(attn_layers), \
-                "All or none of the layers are expected to be encoder-only"
+            total_layers = 0
+            for attn_spec, layer_names in attn_specs.items():

-            attn_backends = get_attn_backends_for_layers(attn_layers.keys())
+                attn_backends = get_attn_backends_for_layers(layer_names)
+                total_layers += len(layer_names)

-            self.attn_groups.append(
-                create_attn_groups(attn_backends, attn_specs[0]))
+                self.attn_groups.append(
+                    create_attn_groups(attn_backends, attn_spec))
+            assert total_layers == len(attn_layers), \
+                "All or none of the layers are expected to be encoder-only"
             self.is_encoder_only_model = True

     def calculate_reorder_batch_threshold(self) -> None:
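
The runner previously assumed a single FullAttentionSpec shared by all encoder-only layers; it now buckets layers by their spec so full-attention and sliding-window layers form separate attention groups. A minimal stand-in for the grouping (class and layer names are illustrative, not vLLM's real types):

    from collections import defaultdict
    from dataclasses import dataclass
    from typing import Optional

    # Frozen (hashable) specs collapse onto one dict key, so all layers that
    # share a spec end up in the same group, mirroring attn_specs above.
    @dataclass(frozen=True)
    class Spec:
        block_size: int
        sliding_window: Optional[int] = None

    layers = {
        "encoder.layers.0.attn": Spec(16, None),
        "encoder.layers.1.attn": Spec(16, 64),
        "encoder.layers.2.attn": Spec(16, 64),
        "encoder.layers.3.attn": Spec(16, None),
    }

    groups: dict = defaultdict(list)
    for name, spec in layers.items():
        groups[spec].append(name)

    # Two groups: full-attention layers (0, 3) and sliding-window layers (1, 2).
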
@@ -3080,7 +3094,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

     def _build_encoder_only_attn_metadata(
             self, scheduler_output: "SchedulerOutput") -> \
-                tuple[CommonAttentionMetadata, Any]:
+                dict[str, tuple[CommonAttentionMetadata, Any]]:
         """Prepare encoder attention metadata for encoder-only models.

         Args:
@@ -3097,33 +3111,45 @@ def _build_encoder_only_attn_metadata(
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         max_num_scheduled_tokens = max(tokens)

-        # Use the first attention metadata builder
-        # to create encoder attention metadata
-        builder = self.attn_groups[0][0].metadata_builder
-
         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
                                         device=self.device)
         dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
                                          dtype=torch.int32,
                                          device=self.device)

-        common_metadata = CommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:num_reqs + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens=self.seq_lens[:num_reqs],
-            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-            num_computed_tokens_cpu=self.input_batch.
-            num_computed_tokens_cpu_tensor[:num_reqs],
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            block_table_tensor=dummy_block_table,
-            slot_mapping=dummy_slot_mapping,
-            causal=False,
-        )
+        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()

-        return common_metadata, builder.build(
-            common_prefix_len=0,  # No cascade for encoder
-            common_attn_metadata=common_metadata,
-        )
+        for attn_group_list in self.attn_groups:
+
+            assert len(attn_group_list) == 1
+            attn_group = attn_group_list[0]
+
+            # Use the first attention metadata builder
+            # to create encoder attention metadata
+            builder = attn_group.metadata_builder
+
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+                seq_lens=self.seq_lens[:num_reqs],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                block_table_tensor=dummy_block_table,
+                slot_mapping=dummy_slot_mapping,
+                causal=False,
+            )
+
+            metadata = builder.build(
+                common_prefix_len=0,  # No cascade for encoder
+                common_attn_metadata=common_metadata,
+            )
+
+            for layer_name in attn_group.layer_names:
+                group_metadata[layer_name] = (common_metadata, metadata)
+
+        return group_metadata
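
With this change the helper returns one (CommonAttentionMetadata, backend metadata) pair per encoder layer instead of a single shared pair, so _prepare_inputs (first hunk of this file) can hand each layer the metadata built by its own group's builder. Roughly, the returned mapping looks like the sketch below (layer names and stand-in objects are illustrative):

    # Illustrative only: the shape of the value returned by
    # _build_encoder_only_attn_metadata after this commit. Layers in the same
    # attention group share the same metadata objects.
    common_full, meta_full = object(), object()  # stand-ins for the full-attention group
    common_swa, meta_swa = object(), object()    # stand-ins for the sliding-window group

    per_layer_metadata = {
        "encoder.layers.0.attn": (common_full, meta_full),  # global-attention layer
        "encoder.layers.1.attn": (common_swa, meta_swa),    # sliding-window layer
        "encoder.layers.2.attn": (common_swa, meta_swa),
    }

    # _prepare_inputs then performs the per-layer lookup:
    common_attn_metadata, encoder_attn_metadata = per_layer_metadata["encoder.layers.1.attn"]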
