Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit ec26653

Browse files
authored
[Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backend (vllm-project#8061)
1 parent 0fbc669 commit ec26653

File tree

5 files changed

+157
-44
lines changed

5 files changed

+157
-44
lines changed

vllm/model_executor/models/blip.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch.nn as nn
88
from PIL import Image
99
from transformers import Blip2VisionConfig, BlipVisionConfig
10-
from xformers import ops as xops
10+
from transformers.models.blip.modeling_blip import BlipAttention
1111

1212
from vllm.config import ModelConfig
1313
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -21,6 +21,12 @@
2121
repeat_and_pad_placeholder_tokens)
2222
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
2323

24+
try:
25+
from xformers import ops as xops
26+
USE_XFORMERS_OPS = True
27+
except ImportError:
28+
USE_XFORMERS_OPS = False
29+
2430

2531
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
2632
assert image_size % patch_size == 0
@@ -156,7 +162,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
156162
return embeddings
157163

158164

159-
class BlipAttention(nn.Module):
165+
class BlipParallelAttention(nn.Module):
160166
"""Multi-headed attention from 'Attention Is All You Need' paper"""
161167

162168
def __init__(
@@ -224,7 +230,7 @@ def forward(
224230
out = out.view(bsz, tgt_len, -1)
225231
attn_output, _ = self.projection(out)
226232

227-
return attn_output
233+
return attn_output, None
228234

229235

230236
class BlipMLP(nn.Module):
@@ -261,7 +267,16 @@ def __init__(self,
261267
quant_config: Optional[QuantizationConfig] = None):
262268
super().__init__()
263269

264-
self.self_attn = BlipAttention(config, quant_config=quant_config)
270+
# fallback to sdpa attention if tp unavailable
271+
num_heads = config.num_attention_heads
272+
tp_size = get_tensor_model_parallel_world_size()
273+
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
274+
self.self_attn = BlipParallelAttention(config,
275+
quant_config=quant_config)
276+
else:
277+
# Blip doesn't have SDPA attention implemented in transformers
278+
# use eager attention instead for cpu backend
279+
self.self_attn = BlipAttention(config)
265280
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
266281
eps=config.layer_norm_eps)
267282
self.mlp = BlipMLP(config, quant_config=quant_config)
@@ -272,7 +287,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
272287
residual = hidden_states
273288

274289
hidden_states = self.layer_norm1(hidden_states)
275-
hidden_states = self.self_attn(hidden_states=hidden_states)
290+
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
276291
hidden_states = residual + hidden_states
277292

278293
residual = hidden_states

vllm/model_executor/models/clip.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch.nn as nn
88
from PIL import Image
99
from transformers import CLIPVisionConfig
10-
from xformers import ops as xops
10+
from transformers.models.clip.modeling_clip import CLIPSdpaAttention
1111

1212
from vllm.config import ModelConfig
1313
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -22,6 +22,12 @@
2222
repeat_and_pad_placeholder_tokens)
2323
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
2424

25+
try:
26+
from xformers import ops as xops
27+
USE_XFORMERS_OPS = True
28+
except ImportError:
29+
USE_XFORMERS_OPS = False
30+
2531

2632
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
2733
assert image_size % patch_size == 0
@@ -162,7 +168,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
162168
return embeddings
163169

164170

165-
class CLIPAttention(nn.Module):
171+
class CLIPParallelAttention(nn.Module):
166172
"""Multi-headed attention from 'Attention Is All You Need' paper"""
167173

168174
def __init__(
@@ -231,7 +237,7 @@ def forward(
231237
out = out.view(bsz, tgt_len, -1)
232238
attn_output, _ = self.out_proj(out)
233239

234-
return attn_output
240+
return attn_output, None
235241

236242

237243
class CLIPMLP(nn.Module):
@@ -266,7 +272,13 @@ def __init__(self,
266272
quant_config: Optional[QuantizationConfig] = None):
267273
super().__init__()
268274

269-
self.self_attn = CLIPAttention(config, quant_config=quant_config)
275+
num_heads = config.num_attention_heads
276+
tp_size = get_tensor_model_parallel_world_size()
277+
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
278+
self.self_attn = CLIPParallelAttention(config,
279+
quant_config=quant_config)
280+
else:
281+
self.self_attn = CLIPSdpaAttention(config)
270282
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
271283
eps=config.layer_norm_eps)
272284
self.mlp = CLIPMLP(config, quant_config=quant_config)
@@ -278,7 +290,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
278290
residual = hidden_states
279291

280292
hidden_states = self.layer_norm1(hidden_states)
281-
hidden_states = self.self_attn(hidden_states=hidden_states)
293+
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
282294
hidden_states = residual + hidden_states
283295

284296
residual = hidden_states
@@ -365,6 +377,10 @@ def __init__(self,
365377
quant_config: Optional[QuantizationConfig] = None,
366378
num_hidden_layers_override: Optional[int] = None):
367379
super().__init__()
380+
tp_size = get_tensor_model_parallel_world_size()
381+
num_heads = config.num_attention_heads
382+
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
383+
368384
self.vision_model = CLIPVisionTransformer(
369385
config=config,
370386
quant_config=quant_config,
@@ -386,7 +402,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
386402
("qkv_proj", "q_proj", "q"),
387403
("qkv_proj", "k_proj", "k"),
388404
("qkv_proj", "v_proj", "v"),
389-
]
405+
] if self.shard_weight else []
390406
params_dict = dict(self.named_parameters())
391407
layer_count = len(self.vision_model.encoder.layers)
392408

vllm/model_executor/models/intern_vit.py

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import torch.nn as nn
1111
import torch.nn.functional as F
1212
from transformers import PretrainedConfig
13-
from xformers import ops as xops
1413

1514
from vllm.distributed import divide, get_tensor_model_parallel_world_size
1615
from vllm.model_executor.layers.activation import get_act_fn
@@ -21,6 +20,12 @@
2120
from vllm.model_executor.layers.quantization import QuantizationConfig
2221
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
2322

23+
try:
24+
from xformers import ops as xops
25+
USE_XFORMERS_OPS = True
26+
except ImportError:
27+
USE_XFORMERS_OPS = False
28+
2429
NORM2FN = {
2530
'rms_norm': RMSNorm,
2631
'layer_norm': nn.LayerNorm,
@@ -81,7 +86,7 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
8186
return embeddings
8287

8388

84-
class InternAttention(nn.Module):
89+
class InternParallelAttention(nn.Module):
8590
"""Multi-headed attention from 'Attention Is All You Need' paper"""
8691

8792
def __init__(
@@ -140,18 +145,67 @@ def forward(self, x):
140145
k = self.k_norm.forward_native(k.flatten(-2,
141146
-1)).view(B_, N_, H_, D_)
142147

143-
x = xops.memory_efficient_attention_forward(
144-
q,
145-
k,
146-
v,
147-
scale=self.scale,
148-
)
148+
x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
149149
x = x.view(B, N, -1)
150150

151151
x, _ = self.proj(x)
152152
return x
153153

154154

155+
class InternSdpaAttention(nn.Module):
156+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
157+
158+
def __init__(self, config: PretrainedConfig):
159+
super().__init__()
160+
self.config = config
161+
self.embed_dim = config.hidden_size
162+
self.num_heads = config.num_attention_heads
163+
self.head_dim = self.embed_dim // self.num_heads
164+
if self.head_dim * self.num_heads != self.embed_dim:
165+
raise ValueError(
166+
f'embed_dim must be divisible by num_heads '
167+
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
168+
f' {self.num_heads}).')
169+
170+
self.scale = self.head_dim**-0.5
171+
self.qkv = nn.Linear(self.embed_dim,
172+
3 * self.embed_dim,
173+
bias=config.qkv_bias)
174+
175+
self.qk_normalization = config.qk_normalization
176+
177+
if self.qk_normalization:
178+
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
179+
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
180+
181+
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
182+
183+
def forward(self, x):
184+
B, N, C = x.shape
185+
qkv = self.qkv(x)
186+
q, k, v = qkv.chunk(3, dim=-1)
187+
188+
q = q.view(B, N, self.num_heads, self.head_dim)
189+
k = k.view(B, N, self.num_heads, self.head_dim)
190+
v = v.view(B, N, self.num_heads, self.head_dim)
191+
192+
if self.qk_normalization:
193+
B_, N_, H_, D_ = q.shape
194+
q = self.q_norm.forward_native(q.flatten(-2,
195+
-1)).view(B_, N_, H_, D_)
196+
k = self.k_norm.forward_native(k.flatten(-2,
197+
-1)).view(B_, N_, H_, D_)
198+
q = q.transpose(1, 2)
199+
k = k.transpose(1, 2)
200+
v = v.transpose(1, 2)
201+
202+
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
203+
x = x.transpose(1, 2).view(B, N, -1)
204+
205+
x = self.proj(x)
206+
return x
207+
208+
155209
class InternMLP(nn.Module):
156210

157211
def __init__(self,
@@ -187,7 +241,14 @@ def __init__(self,
187241
self.intermediate_size = config.intermediate_size
188242
self.norm_type = config.norm_type
189243

190-
self.attn = InternAttention(config, quant_config=quant_config)
244+
# fallback to sdpa attention if tp unavailable
245+
tp_size = get_tensor_model_parallel_world_size()
246+
num_heads = config.num_attention_heads
247+
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
248+
self.attn = InternParallelAttention(config,
249+
quant_config=quant_config)
250+
else:
251+
self.attn = InternSdpaAttention(config)
191252
self.mlp = InternMLP(config, quant_config=quant_config)
192253
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
193254
eps=config.layer_norm_eps)

vllm/model_executor/models/paligemma.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -307,26 +307,30 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
307307
if key_to_modify in name:
308308
name = name.replace(key_to_modify, new_key)
309309
use_default_weight_loading = False
310-
for (param_name, shard_name, shard_id) in stacked_params_mapping:
311-
if shard_name not in name:
312-
continue
313-
name = name.replace(shard_name, param_name)
314-
# Skip loading extra bias for GPTQ models.
315-
if name.endswith(".bias") and name not in params_dict:
316-
continue
317-
param = params_dict[name]
318-
weight_loader = param.weight_loader
319-
weight_loader(param, loaded_weight, shard_id)
320-
break
310+
if "vision" not in name or self.vision_tower.shard_weight:
311+
for (param_name, shard_name,
312+
shard_id) in stacked_params_mapping:
313+
if shard_name not in name:
314+
continue
315+
name = name.replace(shard_name, param_name)
316+
# Skip loading extra bias for GPTQ models.
317+
if name.endswith(".bias") and name not in params_dict:
318+
continue
319+
param = params_dict[name]
320+
weight_loader = param.weight_loader
321+
weight_loader(param, loaded_weight, shard_id)
322+
break
323+
else:
324+
# lm_head is not used in vllm as it is tied with
325+
# embed_token. To prevent errors, skip loading
326+
# lm_head.weight.
327+
if "lm_head.weight" in name:
328+
continue
329+
# Skip loading extra bias for GPTQ models.
330+
if name.endswith(".bias") and name not in params_dict:
331+
continue
332+
use_default_weight_loading = True
321333
else:
322-
# lm_head is not used in vllm as it is tied with
323-
# embed_token. To prevent errors, skip loading
324-
# lm_head.weight.
325-
if "lm_head.weight" in name:
326-
continue
327-
# Skip loading extra bias for GPTQ models.
328-
if name.endswith(".bias") and name not in params_dict:
329-
continue
330334
use_default_weight_loading = True
331335

332336
if use_default_weight_loading:

vllm/model_executor/models/siglip.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from PIL import Image
1010
from torch import nn
1111
from transformers import SiglipVisionConfig
12-
from xformers import ops as xops
12+
from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
1313

1414
from vllm.config import ModelConfig
1515
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -26,6 +26,12 @@
2626
repeat_and_pad_placeholder_tokens)
2727
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
2828

29+
try:
30+
from xformers import ops as xops
31+
USE_XFORMERS_OPS = True
32+
except ImportError:
33+
USE_XFORMERS_OPS = False
34+
2935

3036
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
3137
# Since interpolation is applied, the image size need not be divisible
@@ -219,7 +225,7 @@ def forward(self,
219225
return embeddings
220226

221227

222-
class SiglipAttention(nn.Module):
228+
class SiglipParallelAttention(nn.Module):
223229

224230
def __init__(
225231
self,
@@ -282,7 +288,7 @@ def forward(
282288
out = out.view(batch_size, q_len, -1)
283289
attn_output, _ = self.out_proj(out)
284290

285-
return attn_output
291+
return attn_output, None
286292

287293

288294
class SiglipMLP(nn.Module):
@@ -327,7 +333,14 @@ def __init__(
327333
super().__init__()
328334
self.embed_dim = config.hidden_size
329335

330-
self.self_attn = SiglipAttention(config, quant_config=quant_config)
336+
num_heads = config.num_attention_heads
337+
tp_size = get_tensor_model_parallel_world_size()
338+
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
339+
self.self_attn = SiglipParallelAttention(config,
340+
quant_config=quant_config)
341+
else:
342+
self.self_attn = SiglipSdpaAttention(config)
343+
331344
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
332345
eps=config.layer_norm_eps)
333346
self.mlp = SiglipMLP(
@@ -344,7 +357,7 @@ def forward(
344357
residual = hidden_states
345358

346359
hidden_states = self.layer_norm1(hidden_states)
347-
hidden_states = self.self_attn(hidden_states=hidden_states)
360+
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
348361
hidden_states = residual + hidden_states
349362

350363
residual = hidden_states
@@ -476,6 +489,10 @@ def __init__(
476489
num_hidden_layers_override: Optional[int] = None,
477490
):
478491
super().__init__()
492+
num_heads = config.num_attention_heads
493+
tp_size = get_tensor_model_parallel_world_size()
494+
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
495+
479496
self.vision_model = SiglipVisionTransformer(
480497
config,
481498
quant_config,

0 commit comments

Comments
 (0)