Commit 936997d

Qualcomm AI Engine Direct - Fix the regression of the Whisper model
Summary:
- Resolve the Whisper model accuracy issue caused by upgrading Transformers.
- Modify decompose_sdpa.py to support the "scale" kwarg.
- Fix internal CI.
1 parent 4197fc1 commit 936997d
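For context, the fix centers on the `scale` keyword of `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.1), which the decomposition pass did not previously account for. A minimal standalone sketch, not taken from the commit, showing that the kwarg changes the result:

# Standalone sketch (assumes PyTorch >= 2.1, where the `scale` kwarg exists).
# `scale` overrides the default 1/sqrt(head_dim) factor applied to q @ k^T.
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 4, 100, 64) for _ in range(3))

out_default = F.scaled_dot_product_attention(q, k, v)            # scale = 1/sqrt(64)
out_custom = F.scaled_dot_product_attention(q, k, v, scale=0.5)  # explicit scale
print(torch.allclose(out_default, out_custom))                   # False: scale changed the output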

File tree

5 files changed: +63, -33 lines

backends/qualcomm/tests/models.py

Lines changed: 3 additions & 2 deletions
@@ -1530,12 +1530,13 @@ def forward(self, x):
 
 
 class ScaledDotProductAttention(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, scale=None):
         super().__init__()
+        self.scale = scale
 
     def forward(self, query_layer, key_layer, value_layer, attn_mask):
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_layer, key_layer, value_layer, attn_mask
+            query_layer, key_layer, value_layer, attn_mask, scale=self.scale
         )
         return attn_output
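A usage sketch for the updated test module; the import path is an assumption based on the file location above, and with scale=None the operator keeps its default 1/sqrt(head_dim) scaling:

# Usage sketch (import path assumed from backends/qualcomm/tests/models.py).
import torch
from executorch.backends.qualcomm.tests.models import ScaledDotProductAttention

module = ScaledDotProductAttention(scale=0.5)  # scale=None -> default scaling
q, k, v = (torch.randn(1, 4, 100, 64) for _ in range(3))
mask = torch.tril(torch.randn(1, 1, 100, 100))
mask[mask == 0] = float("-inf")  # block masked positions, as in the tests
out = module(q, k, v, mask)      # -> shape (1, 4, 100, 64)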

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 32 additions & 14 deletions
@@ -1008,7 +1008,11 @@ def test_qnn_backend_rsqrt(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_sdpa(self):
-        module = ScaledDotProductAttention()  # noqa: F405
+        modules = [
+            ScaledDotProductAttention(),  # noqa: F405
+            ScaledDotProductAttention(scale=0.5),  # noqa: F405
+            ScaledDotProductAttention(scale=1.0),  # noqa: F405
+        ]
         mask = torch.tril(torch.randn(1, 1, 100, 100))
         mask[mask == 0] = float("-inf")
         sample_input = (
@@ -1017,7 +1021,9 @@ def test_qnn_backend_sdpa(self):
             torch.randn(1, 4, 100, 64),
             mask,
         )
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_sigmoid(self):
         module = Sigmoid()  # noqa: F405
@@ -2414,7 +2420,11 @@ def test_qnn_backend_rsqrt(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_sdpa(self):
-        module = ScaledDotProductAttention()  # noqa: F405
+        modules = [
+            ScaledDotProductAttention(),  # noqa: F405
+            ScaledDotProductAttention(scale=0.5),  # noqa: F405
+            ScaledDotProductAttention(scale=1.0),  # noqa: F405
+        ]
         mask = torch.tril(torch.randn(1, 1, 100, 100))
         mask[mask == 0] = torch.finfo(torch.float32).min
         sample_input = (
@@ -2423,8 +2433,12 @@ def test_qnn_backend_sdpa(self):
             torch.randn(1, 4, 100, 64),
             mask,
         )
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    module, sample_input, quant_dtype=QuantDtype.use_16a8w
+                )
+                self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_select_copy(self):
         module = SelectCopy()  # noqa: F405
@@ -4949,13 +4963,14 @@ def test_gMLP(self):
         self.assertGreaterEqual(msg["top_1"], 60)
         self.assertGreaterEqual(msg["top_5"], 85)
 
-    def test_mobilevit_v1(self):
+    @unittest.skip("Only outputs good accuracy in QNN 2.29")
+    def test_mobilevit_v2(self):
         if not self.required_envs([self.image_dataset]):
             self.skipTest("missing required envs")
 
         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit_v1.py"
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit_v2.py",
             "--dataset",
             self.image_dataset,
             "--artifact",
@@ -4973,6 +4988,8 @@ def test_mobilevit_v1(self):
         ]
         if self.host:
             cmds.extend(["--host", self.host])
+        if self.shared_buffer:
+            cmds.extend(["--shared_buffer"])
 
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
@@ -4982,17 +4999,16 @@ def test_mobilevit_v1(self):
             if "Error" in msg:
                 self.fail(msg["Error"])
             else:
-                self.assertGreaterEqual(msg["top_1"], 70)
+                self.assertGreaterEqual(msg["top_1"], 50)
                 self.assertGreaterEqual(msg["top_5"], 85)
 
-    @unittest.skip("Only outputs good accuracy in QNN 2.29")
-    def test_mobilevit_v2(self):
+    def test_mobilevit1(self):
         if not self.required_envs([self.image_dataset]):
             self.skipTest("missing required envs")
 
         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit_v2.py",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit1.py",
             "--dataset",
             self.image_dataset,
             "--artifact",
@@ -5010,8 +5026,6 @@ def test_mobilevit_v2(self):
         ]
         if self.host:
             cmds.extend(["--host", self.host])
-        if self.shared_buffer:
-            cmds.extend(["--shared_buffer"])
 
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
@@ -5021,7 +5035,7 @@
             if "Error" in msg:
                 self.fail(msg["Error"])
             else:
-                self.assertGreaterEqual(msg["top_1"], 50)
+                self.assertGreaterEqual(msg["top_1"], 70)
                 self.assertGreaterEqual(msg["top_5"], 85)
 
     def test_pvt(self):
@@ -5031,7 +5045,11 @@ def test_pvt(self):
         cmds = [
             "python",
             f"{self.executorch_root}/examples/qualcomm/oss_scripts/pvt.py",
+            "--dataset",
             self.image_dataset,
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
             self.build_folder,
             "--device",
             self.device,
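The tests above switch from a single module to a list of scale variants exercised under unittest's subTest, so one failing configuration is reported without masking the others. A minimal illustration of the pattern, separate from the repo's test harness:

# Minimal subTest illustration (not from the repo): each variant is reported
# as its own sub-result, so a failure at scale=0.5 would not hide scale=1.0.
import unittest

class SdpaVariants(unittest.TestCase):
    def test_scales(self):
        for i, scale in enumerate([None, 0.5, 1.0]):
            with self.subTest(i=i, scale=scale):
                self.assertTrue(scale is None or scale > 0)

if __name__ == "__main__":
    unittest.main()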

backends/transforms/decompose_sdpa.py

Lines changed: 13 additions & 0 deletions
@@ -6,6 +6,8 @@
 
 # pyre-strict
 
+import math
+
 import torch
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._decomp import get_decompositions
@@ -30,6 +32,7 @@ def call(
         for node in graph.nodes:
             if node.target == torch.ops.aten.scaled_dot_product_attention.default:
                 input_tensors = (arg.meta["val"] for arg in node.args)
+                scale = node.kwargs.get("scale", None)
 
                 # refer to pytorch/test/test_decomp.py
                 decomposed_module = make_fx(
@@ -81,6 +84,16 @@ def call(
                         )
                         continue
 
+                    if scale is not None and decomposed_node.target in [
+                        torch.ops.aten.mul.Scalar
+                    ]:
+                        new_args = list(decomposed_node.args)
+                        # Based on the implementation of _scaled_dot_product_attention_math,
+                        # the scale is applied to q and k before matmul.
+                        # refer to pytorch/aten/src/ATen/native/transformers/attention.cpp#L873
+                        new_args[1] = math.sqrt(scale)
+                        decomposed_node.args = tuple(new_args)
+
                     subgraph_node = graph.node_copy(
                         decomposed_node,
                         arg_transform=lambda x: decomposed_node_to_subgraph_node[  # noqa: B023
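The math.sqrt(scale) rewrite above relies on the decomposition multiplying q and k separately: each picks up a factor of sqrt(scale), so their matmul carries the full scale. A quick standalone check of that identity:

# Standalone check of the sqrt(scale) identity used by the pass above:
# (sqrt(s) * q) @ (sqrt(s) * k)^T == s * (q @ k^T)
import math
import torch

s = 0.5
q = torch.randn(100, 64)
k = torch.randn(100, 64)

split = (q * math.sqrt(s)) @ (k * math.sqrt(s)).T  # scale split across q and k
direct = (q @ k.T) * s                             # scale applied to the logits
print(torch.allclose(split, direct, atol=1e-5))    # True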

examples/qualcomm/oss_scripts/whisper/whisper.py

Lines changed: 9 additions & 8 deletions
@@ -36,8 +36,8 @@
 
 from executorch.devtools.backend_debug import print_delegation_info
 from executorch.examples.qualcomm.oss_scripts.whisper.whisper_model import (
-    Seq2SeqLMDecoderExportableModuleWithStaticCache,
-    Seq2SeqLMEncoderExportableModule,
+    QnnSeq2SeqLMDecoderExportableModuleWithStaticCache,
+    QnnSeq2SeqLMEncoderExportableModule,
 )
 
 from executorch.examples.qualcomm.utils import (
@@ -169,14 +169,14 @@ def __init__(
         )
 
         self.whisper_encoder = (
-            Seq2SeqLMEncoderExportableModule(whisper_model.get_encoder())
+            QnnSeq2SeqLMEncoderExportableModule(whisper_model.get_encoder())
             .to("cpu")
             .eval()
         )
         self.encoder_passes_job = get_capture_program_passes()
 
         self.whisper_decoder = (
-            Seq2SeqLMDecoderExportableModuleWithStaticCache(
+            QnnSeq2SeqLMDecoderExportableModuleWithStaticCache(
                 whisper_model=whisper_model,
                 max_cache_length=self.max_seq_length,
                 batch_size=batch_size,
@@ -190,20 +190,21 @@ def __init__(
         self.exported_whisper_encoder = None
         self.exported_whisper_decoder = None
         self.has_quant_io = False
+        self.kv_shape = {
+            (self.max_seq_length, self.head_dim),
+        }
 
     def _tag_ios(self, node, fixed_point_type):
         if not self.has_quant_io:
             return
 
         quant_io_type = None
-        if node.op == "placeholder" and "static_cache_" in node.name:
+        if node.op == "placeholder" and node.meta["val"].size()[-2:] in self.kv_shape:
             quant_io_type = fixed_point_type
 
         if is_graph_output(node):
             # shape of k caches and v caches
-            if node.meta["val"].size()[-2:] in {
-                (self.max_seq_length, self.head_dim),
-            }:
+            if node.meta["val"].size()[-2:] in self.kv_shape:
                 quant_io_type = fixed_point_type
 
         return quant_io_type
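The _tag_ios change above stops matching placeholder names ("static_cache_" no longer appears once the cache buffers aren't registered by hand, per the whisper_model.py diff below) and instead tags any IO whose trailing two dimensions match a KV-cache shape. An illustrative check of that condition, with hypothetical numbers:

# Illustrative shape test mirroring the new _tag_ios condition; the concrete
# max_seq_length and head_dim values here are hypothetical.
import torch

kv_shape = {(1024, 64)}             # {(max_seq_length, head_dim)}
val = torch.zeros(1, 20, 1024, 64)  # e.g. (batch, heads, seq, head_dim)

# torch.Size is a tuple subclass, so set membership works directly.
print(val.size()[-2:] in kv_shape)  # True -> would be tagged as fixed-point IO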

examples/qualcomm/oss_scripts/whisper/whisper_model.py

Lines changed: 6 additions & 9 deletions
@@ -6,10 +6,11 @@
 
 
 import torch
-from transformers import StaticCache, WhisperForConditionalGeneration
+from transformers.cache_utils import DynamicCache, EncoderDecoderCache, StaticCache
+from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration
 
 
-class Seq2SeqLMEncoderExportableModule(torch.nn.Module):
+class QnnSeq2SeqLMEncoderExportableModule(torch.nn.Module):
     """
     A wrapper module designed to make a Seq2Seq LM encoder exportable with `torch.export`.
     This module ensures that the exported encoder model is compatible with ExecuTorch.
@@ -29,7 +30,7 @@ def get_metadata(self):
         return {}
 
 
-class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
+class QnnSeq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
     """
     A wrapper module designed to make a Seq2Seq LM decoder exportable with `torch.export`,
     specifically for use with static caching. This module ensures the exported decoder
@@ -57,11 +58,7 @@ def __init__(self, whisper_model, max_cache_length, batch_size):
             device="cpu",
             dtype=torch.float32,
         )
-
-        # Register cache buffers to make them exportable
-        for i in range(len(self.static_cache.key_cache)):
-            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i])
-            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i])
+        self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())
 
     def forward(
         self, decoder_input_ids, attention_mask, encoder_hidden_states, cache_position
@@ -71,7 +68,7 @@ def forward(
             input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             encoder_hidden_states=encoder_hidden_states,
-            past_key_values=self.static_cache,
+            past_key_values=self.cache,
             use_cache=True,
             cache_position=cache_position,
         )
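The decoder wrapper now hands Whisper an EncoderDecoderCache, pairing the preallocated static self-attention cache with a DynamicCache for cross-attention, which is what recent Transformers releases expect instead of a bare StaticCache with manually registered buffers. A hedged sketch of the construction; the exact StaticCache kwargs have shifted between Transformers versions, so treat them as assumptions:

# Sketch of the new cache plumbing; assumes a recent Transformers release and
# a loaded Whisper model. StaticCache kwargs are an assumption here.
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, StaticCache
from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

static_cache = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=1024,
    device="cpu",
    dtype=torch.float32,
)
# Self-attention reuses the preallocated static cache; cross-attention keys
# and values (derived from the encoder output) go into the dynamic cache.
cache = EncoderDecoderCache(static_cache, DynamicCache())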
