Commit bb467e4
coreml: fix Whisper to CoreML conversion by disabling SDPA [no ci]
This commit disables the use of PyTorch's `scaled_dot_product_attention` in the Whisper model to avoid compatibility issues during CoreML conversion.

The issue occurs because coremltools requires PyTorch 2.5.0, but the Whisper implementation may expect behavior from newer PyTorch versions. Setting `MultiHeadAttention.use_sdpa = False` forces Whisper to use its fallback manual attention implementation, which works correctly with PyTorch 2.5.0 during the tracing process.

Refs: #2783
Parent: 04b9508
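For context, here is a minimal standalone sketch (not part of this commit; the model size, input shape, and tracing call are illustrative assumptions based on what the conversion script does) of why the patch must run first: torch.jit.trace records whichever attention code path actually executes, so SDPA has to be disabled before tracing begins.

    import torch
    import whisper.model
    from whisper import load_model

    # Apply the patch before any tracing happens: torch.jit.trace bakes in
    # whichever attention branch runs, so SDPA must already be switched off.
    whisper.model.MultiHeadAttention.use_sdpa = False

    model = load_model("tiny")  # hypothetical model choice for illustration
    model.eval()

    # Dummy log-mel spectrogram input: (batch, n_mels, n_frames).
    mel = torch.randn(1, 80, 3000)

    # Tracing now records the manual attention implementation, which behaves
    # consistently under the PyTorch 2.5.0 pin required by coremltools.
    traced_encoder = torch.jit.trace(model.encoder, mel)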

1 file changed: 9 additions, 0 deletions

models/convert-whisper-to-coreml.py

@@ -12,6 +12,15 @@
 from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
 from whisper import load_model
 
+# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
+# The Whisper implementation expects a specific behavior from
+# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
+# versions. Setting use_sdpa=False forces Whisper to use its manual attention
+# implementation instead, which is more stable across different PyTorch versions
+# (2.5.0 required by coremltools vs newer versions).
+import whisper.model
+whisper.model.MultiHeadAttention.use_sdpa = False
+
 # Use for changing dim of input in encoder and decoder embeddings
 def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
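As a side note, this self-contained sketch (not from the commit; tensor shapes are chosen arbitrarily) illustrates why the fallback is a safe substitute: the manual matmul-plus-softmax path computes the same result as the fused SDPA kernel, it just avoids the version-sensitive fused operator.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    q = torch.randn(1, 6, 10, 64)  # (batch, heads, seq, head_dim)
    k = torch.randn(1, 6, 10, 64)
    v = torch.randn(1, 6, 10, 64)

    # Fused kernel path: what runs when MultiHeadAttention.use_sdpa is True.
    fused = F.scaled_dot_product_attention(q, k, v)

    # Manual fallback path: explicit scaled matmul + softmax, the kind of
    # computation selected by use_sdpa = False. Mathematically equivalent.
    scale = q.shape[-1] ** -0.5
    weights = ((q @ k.transpose(-2, -1)) * scale).softmax(dim=-1)
    manual = weights @ v

    print(torch.allclose(fused, manual, atol=1e-5))  # expected: True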
