
Conversation

@shewu-quic
Collaborator

Background

We observed that quantizing and compiling the original sha model takes a significant amount of time, while switching to the mha model speeds up this process. We therefore investigated whether the mha model could be converted to sha after quantization. This conversion cannot happen during the to_edge transformation, because splitting the convolution weights into per-head sha weights requires modifying the state_dict, which is not permitted at that stage. Instead, we apply the pass during qnn_preprocess.
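
For illustration, a minimal sketch of the kind of weight splitting involved (the helper below is hypothetical, not the actual pass): splitting a fused multi-head 1x1 convolution projection into per-head convolutions creates new weight tensors, which is exactly the state_dict mutation that to_edge does not allow.

```python
# Minimal sketch (hypothetical helper, not the actual pass): split a fused
# multi-head 1x1 conv projection into per-head convs. The new per-head weight
# tensors are the state_dict mutation that cannot happen during to_edge.
import torch.nn as nn


def split_mha_conv_to_sha(conv: nn.Conv2d, n_heads: int) -> nn.ModuleList:
    head_dim = conv.out_channels // n_heads
    sha_convs = nn.ModuleList()
    for h in range(n_heads):
        sha = nn.Conv2d(
            conv.in_channels, head_dim, kernel_size=1, bias=conv.bias is not None
        )
        # Copy this head's slice of the fused weight into the new conv.
        sha.weight.data.copy_(conv.weight[h * head_dim : (h + 1) * head_dim])
        if conv.bias is not None:
            sha.bias.data.copy_(conv.bias[h * head_dim : (h + 1) * head_dim])
        sha_convs.append(sha)
    return sha_convs
```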

Summary:

  • Added the mha-to-sha conversion pass and applied it in qnn_preprocess
  • Refactored mha in static llama
    • Included spin quant r3 support and masked softmax for the mha model in static llama
    • Combined the n_heads key-value caches into a single cache for each layer to decrease the number of inputs and outputs, which enhances performance (see the sketch after this list).
  • Deprecated ShiftPointer kv updater mode
    • Since each layer now has its own kv cache, the v cache no longer benefits from ShiftPointer, which previously avoided copying the new v cache into the input v cache. To prevent user confusion, ShiftPointer mode has been deprecated.
  • Applied the correct input template for smollm2 135m
  • Corrected the quantization annotation for reshape
  • Removed outdated code from CanonicalizeConv
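
The sketch below illustrates the combined per-layer kv cache; the shapes are assumptions for illustration only, not the exact static llama layout.

```python
# Illustrative shapes only (assumptions, not the exact static llama layout).
import torch

n_heads, head_dim, max_seq_len = 8, 64, 512

# Before: one cache tensor per head per layer -> n_heads graph inputs/outputs.
k_caches = [torch.zeros(1, head_dim, max_seq_len) for _ in range(n_heads)]

# After: one fused cache tensor per layer -> a single graph input/output.
k_cache = torch.zeros(1, n_heads, head_dim, max_seq_len)

# Per-head views are still recoverable inside the graph when needed.
per_head = k_cache.unbind(dim=1)  # n_heads tensors of shape [1, head_dim, max_seq_len]
assert len(per_head) == n_heads
```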

Results

Following the README settings, we tested on SM8750 with QNN 2.37 and compared the new convert_mha_to_sha pass with the original sha structure.

[image: lowering-time comparison between the new convert_mha_to_sha pass and the original sha structure]

@shewu-quic requested a review from cccclai as a code owner on October 29, 2025 06:56
@pytorch-bot

pytorch-bot bot commented Oct 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15438

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 18e7db1 with merge base 3485495:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label on Oct 29, 2025
@shewu-quic
Collaborator Author

@pytorchbot label "release notes: qualcomm"

pytorch-bot bot added the release notes: qualcomm label on Oct 29, 2025
@shewu-quic
Collaborator Author

Hi @cccclai,
This PR migrates the mha2sha transformation from the source level to a pass applied in qnn_preprocess. It significantly improves lowering time, including both quantization and compilation time.
Could you please take a look?

Thanks
