Skip to content

Commit cbbbc3a

Browse files
authored
Qualcomm AI Engine Direct - Fix the regression of whisper model (#13062)
Summary: - Resolve the Whisper model accuracy regression caused by upgrading the Transformers library. - Modify decompose_sdpa.py to support the "scale" keyword argument. - Fix internal CI. cc: @haowhsu-quic, @winskuo-quic
1 parent c1dba0f commit cbbbc3a

File tree

5 files changed

+63
-33
lines changed

5 files changed

+63
-33
lines changed

backends/qualcomm/tests/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,12 +1530,13 @@ def forward(self, x):
15301530

15311531

15321532
class ScaledDotProductAttention(torch.nn.Module):
1533-
def __init__(self):
1533+
def __init__(self, scale=None):
15341534
super().__init__()
1535+
self.scale = scale
15351536

15361537
def forward(self, query_layer, key_layer, value_layer, attn_mask):
15371538
attn_output = torch.nn.functional.scaled_dot_product_attention(
1538-
query_layer, key_layer, value_layer, attn_mask
1539+
query_layer, key_layer, value_layer, attn_mask, scale=self.scale
15391540
)
15401541
return attn_output
15411542

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,11 @@ def test_qnn_backend_rsqrt(self):
10081008
self.lower_module_and_test_output(module, sample_input)
10091009

10101010
def test_qnn_backend_sdpa(self):
1011-
module = ScaledDotProductAttention() # noqa: F405
1011+
modules = [
1012+
ScaledDotProductAttention(), # noqa: F405
1013+
ScaledDotProductAttention(scale=0.5), # noqa: F405
1014+
ScaledDotProductAttention(scale=1.0), # noqa: F405
1015+
]
10121016
mask = torch.tril(torch.randn(1, 1, 100, 100))
10131017
mask[mask == 0] = float("-inf")
10141018
sample_input = (
@@ -1017,7 +1021,9 @@ def test_qnn_backend_sdpa(self):
10171021
torch.randn(1, 4, 100, 64),
10181022
mask,
10191023
)
1020-
self.lower_module_and_test_output(module, sample_input)
1024+
for i, module in enumerate(modules):
1025+
with self.subTest(i=i):
1026+
self.lower_module_and_test_output(module, sample_input)
10211027

10221028
def test_qnn_backend_sigmoid(self):
10231029
module = Sigmoid() # noqa: F405
@@ -2414,7 +2420,11 @@ def test_qnn_backend_rsqrt(self):
24142420
self.lower_module_and_test_output(module, sample_input)
24152421

24162422
def test_qnn_backend_sdpa(self):
2417-
module = ScaledDotProductAttention() # noqa: F405
2423+
modules = [
2424+
ScaledDotProductAttention(), # noqa: F405
2425+
ScaledDotProductAttention(scale=0.5), # noqa: F405
2426+
ScaledDotProductAttention(scale=1.0), # noqa: F405
2427+
]
24182428
mask = torch.tril(torch.randn(1, 1, 100, 100))
24192429
mask[mask == 0] = torch.finfo(torch.float32).min
24202430
sample_input = (
@@ -2423,8 +2433,12 @@ def test_qnn_backend_sdpa(self):
24232433
torch.randn(1, 4, 100, 64),
24242434
mask,
24252435
)
2426-
module = self.get_qdq_module(module, sample_input)
2427-
self.lower_module_and_test_output(module, sample_input)
2436+
for i, module in enumerate(modules):
2437+
with self.subTest(i=i):
2438+
module = self.get_qdq_module(
2439+
module, sample_input, quant_dtype=QuantDtype.use_16a8w
2440+
)
2441+
self.lower_module_and_test_output(module, sample_input)
24282442

24292443
def test_qnn_backend_select_copy(self):
24302444
module = SelectCopy() # noqa: F405
@@ -4951,13 +4965,14 @@ def test_gMLP(self):
49514965
self.assertGreaterEqual(msg["top_1"], 60)
49524966
self.assertGreaterEqual(msg["top_5"], 85)
49534967

4954-
def test_mobilevit_v1(self):
4968+
@unittest.skip("Only outputs good accuracy in QNN 2.29")
4969+
def test_mobilevit_v2(self):
49554970
if not self.required_envs([self.image_dataset]):
49564971
self.skipTest("missing required envs")
49574972

49584973
cmds = [
49594974
"python",
4960-
f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit_v1.py"
4975+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit_v2.py",
49614976
"--dataset",
49624977
self.image_dataset,
49634978
"--artifact",
@@ -4975,6 +4990,8 @@ def test_mobilevit_v1(self):
49754990
]
49764991
if self.host:
49774992
cmds.extend(["--host", self.host])
4993+
if self.shared_buffer:
4994+
cmds.extend(["--shared_buffer"])
49784995

49794996
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
49804997
with Listener((self.ip, self.port)) as listener:
@@ -4984,17 +5001,16 @@ def test_mobilevit_v1(self):
49845001
if "Error" in msg:
49855002
self.fail(msg["Error"])
49865003
else:
4987-
self.assertGreaterEqual(msg["top_1"], 70)
5004+
self.assertGreaterEqual(msg["top_1"], 50)
49885005
self.assertGreaterEqual(msg["top_5"], 85)
49895006

4990-
@unittest.skip("Only outputs good accuracy in QNN 2.29")
4991-
def test_mobilevit_v2(self):
5007+
def test_mobilevit1(self):
49925008
if not self.required_envs([self.image_dataset]):
49935009
self.skipTest("missing required envs")
49945010

49955011
cmds = [
49965012
"python",
4997-
f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit_v2.py",
5013+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit1.py",
49985014
"--dataset",
49995015
self.image_dataset,
50005016
"--artifact",
@@ -5012,8 +5028,6 @@ def test_mobilevit_v2(self):
50125028
]
50135029
if self.host:
50145030
cmds.extend(["--host", self.host])
5015-
if self.shared_buffer:
5016-
cmds.extend(["--shared_buffer"])
50175031

50185032
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
50195033
with Listener((self.ip, self.port)) as listener:
@@ -5023,7 +5037,7 @@ def test_mobilevit_v2(self):
50235037
if "Error" in msg:
50245038
self.fail(msg["Error"])
50255039
else:
5026-
self.assertGreaterEqual(msg["top_1"], 50)
5040+
self.assertGreaterEqual(msg["top_1"], 70)
50275041
self.assertGreaterEqual(msg["top_5"], 85)
50285042

50295043
def test_pvt(self):
@@ -5033,7 +5047,11 @@ def test_pvt(self):
50335047
cmds = [
50345048
"python",
50355049
f"{self.executorch_root}/examples/qualcomm/oss_scripts/pvt.py",
5050+
"--dataset",
50365051
self.image_dataset,
5052+
"--artifact",
5053+
self.artifact_dir,
5054+
"--build_folder",
50375055
self.build_folder,
50385056
"--device",
50395057
self.device,

backends/transforms/decompose_sdpa.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
# pyre-strict
88

9+
import math
10+
911
import torch
1012
from executorch.exir.pass_base import ExportPass, PassResult
1113
from torch._decomp import get_decompositions
@@ -30,6 +32,7 @@ def call(
3032
for node in graph.nodes:
3133
if node.target == torch.ops.aten.scaled_dot_product_attention.default:
3234
input_tensors = (arg.meta["val"] for arg in node.args)
35+
scale = node.kwargs.get("scale", None)
3336

3437
# refer to pytorch/test/test_decomp.py
3538
decomposed_module = make_fx(
@@ -81,6 +84,16 @@ def call(
8184
)
8285
continue
8386

87+
if scale is not None and decomposed_node.target in [
88+
torch.ops.aten.mul.Scalar
89+
]:
90+
new_args = list(decomposed_node.args)
91+
# Based on the implementation of _scaled_dot_product_attention_math,
92+
# sqrt(scale) is applied to both q and k before the matmul, so each scalar multiply uses sqrt(scale).
93+
# refer to pytorch/aten/src/ATen/native/transformers/attention.cpp#L873
94+
new_args[1] = math.sqrt(scale)
95+
decomposed_node.args = tuple(new_args)
96+
8497
subgraph_node = graph.node_copy(
8598
decomposed_node,
8699
arg_transform=lambda x: decomposed_node_to_subgraph_node[ # noqa: B023

examples/qualcomm/oss_scripts/whisper/whisper.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636

3737
from executorch.devtools.backend_debug import print_delegation_info
3838
from executorch.examples.qualcomm.oss_scripts.whisper.whisper_model import (
39-
Seq2SeqLMDecoderExportableModuleWithStaticCache,
40-
Seq2SeqLMEncoderExportableModule,
39+
QnnSeq2SeqLMDecoderExportableModuleWithStaticCache,
40+
QnnSeq2SeqLMEncoderExportableModule,
4141
)
4242

4343
from executorch.examples.qualcomm.utils import (
@@ -169,14 +169,14 @@ def __init__(
169169
)
170170

171171
self.whisper_encoder = (
172-
Seq2SeqLMEncoderExportableModule(whisper_model.get_encoder())
172+
QnnSeq2SeqLMEncoderExportableModule(whisper_model.get_encoder())
173173
.to("cpu")
174174
.eval()
175175
)
176176
self.encoder_passes_job = get_capture_program_passes()
177177

178178
self.whisper_decoder = (
179-
Seq2SeqLMDecoderExportableModuleWithStaticCache(
179+
QnnSeq2SeqLMDecoderExportableModuleWithStaticCache(
180180
whisper_model=whisper_model,
181181
max_cache_length=self.max_seq_length,
182182
batch_size=batch_size,
@@ -190,20 +190,21 @@ def __init__(
190190
self.exported_whisper_encoder = None
191191
self.exported_whisper_decoder = None
192192
self.has_quant_io = False
193+
self.kv_shape = {
194+
(self.max_seq_length, self.head_dim),
195+
}
193196

194197
def _tag_ios(self, node, fixed_point_type):
195198
if not self.has_quant_io:
196199
return
197200

198201
quant_io_type = None
199-
if node.op == "placeholder" and "static_cache_" in node.name:
202+
if node.op == "placeholder" and node.meta["val"].size()[-2:] in self.kv_shape:
200203
quant_io_type = fixed_point_type
201204

202205
if is_graph_output(node):
203206
# shape of k caches and v caches
204-
if node.meta["val"].size()[-2:] in {
205-
(self.max_seq_length, self.head_dim),
206-
}:
207+
if node.meta["val"].size()[-2:] in self.kv_shape:
207208
quant_io_type = fixed_point_type
208209

209210
return quant_io_type

examples/qualcomm/oss_scripts/whisper/whisper_model.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77

88
import torch
9-
from transformers import StaticCache, WhisperForConditionalGeneration
9+
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, StaticCache
10+
from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration
1011

1112

12-
class Seq2SeqLMEncoderExportableModule(torch.nn.Module):
13+
class QnnSeq2SeqLMEncoderExportableModule(torch.nn.Module):
1314
"""
1415
A wrapper module designed to make a Seq2Seq LM encoder exportable with `torch.export`.
1516
This module ensures that the exported encoder model is compatible with ExecuTorch.
@@ -29,7 +30,7 @@ def get_metadata(self):
2930
return {}
3031

3132

32-
class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
33+
class QnnSeq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
3334
"""
3435
A wrapper module designed to make a Seq2Seq LM decoder exportable with `torch.export`,
3536
specifically for use with static caching. This module ensures the exported decoder
@@ -57,11 +58,7 @@ def __init__(self, whisper_model, max_cache_length, batch_size):
5758
device="cpu",
5859
dtype=torch.float32,
5960
)
60-
61-
# Register cache buffers to make them exportable
62-
for i in range(len(self.static_cache.key_cache)):
63-
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i])
64-
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i])
61+
self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())
6562

6663
def forward(
6764
self, decoder_input_ids, attention_mask, encoder_hidden_states, cache_position
@@ -71,7 +68,7 @@ def forward(
7168
input_ids=decoder_input_ids,
7269
attention_mask=attention_mask,
7370
encoder_hidden_states=encoder_hidden_states,
74-
past_key_values=self.static_cache,
71+
past_key_values=self.cache,
7572
use_cache=True,
7673
cache_position=cache_position,
7774
)

0 commit comments

Comments
 (0)