Skip to content

Commit bfa2c0b

Browse files
[ROCm][Bugfix] Fix RuntimeError in MMEncoderAttention by replacing .view() with .reshape() (vllm-project#31203)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
1 parent f790068 commit bfa2c0b

File tree

2 files changed: +3 −3 lines changed

tests/models/multimodal/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def pytest_collection_modifyitems(config, items):
         return

     # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
-    # accuracy issues
+    # accuracy issues: https://github.com/vllm-project/vllm/issues/30167
     # TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
     torch.backends.cuda.enable_flash_sdp(False)
     torch.backends.cuda.enable_mem_efficient_sdp(False)

vllm/attention/layers/mm_encoder_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def _forward_sdpa(
             cu_seqlens=cu_seqlens,
         )
         if is_reshaped:
-            output = output.view(bsz, q_len, -1)
+            output = output.reshape(bsz, q_len, -1)
         return output

     def _forward_fa(
@@ -174,7 +174,7 @@ def _forward_fa(
             fa_version=self._fa_version,
         )
         if is_reshaped:
-            output = output.view(bsz, q_len, -1)
+            output = output.reshape(bsz, q_len, -1)
         return output

     def forward_native(

0 commit comments

Comments
 (0)