
Conversation

@cavusmustafa (Contributor) commented Sep 5, 2025

Summary

This PR includes the changes below:

Release notes: backends

Test plan

The features provided in this PR are tested using the instructions in the README files below. The Llama export functionality with the OpenVINO backend is not part of the backend test suite for the time being due to model file dependencies.
OpenVINO Backend Setup:
https://github.com/cavusmustafa/executorch/blob/openvino_llama_support/backends/openvino/README.md
OpenVINO Backend Example for Llama:
https://github.com/cavusmustafa/executorch/blob/openvino_llama_support/examples/openvino/llama/README.md

Performance Results

Model: meta-llama/Llama-3.2-1B-Instruct

| Configuration | Tokens per second | ms per token |
| --- | --- | --- |
| XNNPACK CPU FP32 | 14.1 | 68 |
| OpenVINO CPU FP32 | 16.6 | 60.3 |
| OpenVINO CPU INT4 | 63.6 | 15.7 |
| OpenVINO NPU INT4 | 42.9 | 23.3 |
| OpenVINO NPU INT4 Channel Wise Quant | 50.4 | 19.8 |
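
For reference, the two result columns are approximately reciprocal: ms per token is 1000 divided by tokens per second. A minimal sketch of the conversion, reusing the tokens-per-second values from the table above (the reported ms-per-token figures were measured separately, so they do not always match the exact reciprocal):

```python
# Tokens/s <-> ms/token conversion for the configurations reported above.
# Values are copied from the table; ms/token here is the reciprocal, which
# may differ slightly from the independently measured numbers in the table.
results_tok_s = {
    "XNNPACK CPU FP32": 14.1,
    "OpenVINO CPU FP32": 16.6,
    "OpenVINO CPU INT4": 63.6,
    "OpenVINO NPU INT4": 42.9,
    "OpenVINO NPU INT4 Channel Wise Quant": 50.4,
}

for config, tok_s in results_tok_s.items():
    ms_per_token = 1000.0 / tok_s
    print(f"{config:40s} {tok_s:6.1f} tok/s  ~{ms_per_token:5.1f} ms/token")
```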

System Config:
Hardware: Intel(R) Core(TM) Ultra 7 258V (Intel Lunar Lake)
Architecture: x86_64, 8 cores, 4.8 GHz max
Memory: 32GB RAM
OS: Ubuntu 24.04

CC: @ynimmaga @suryasidd @anzr299 @daniil-lyakhov @MaximProshin @kimishpatel @cbilgin

@kimishpatel (Contributor) commented:

Thanks for the PR, I will do a review. I do see some CI failures.
Also, why is NPU INT4 actually slower? Is there a config where the NPU is actually faster than the CPU?

Hey Kimish, for NPU we have updated the numbers. We are seeing much better performance than previously reported with different config options. With channel-wise quantization we are getting about 50 tok/s. We're currently working on additional optimizations to bring NPU performance closer to CPU.

Ok, let me review the diff. However, I would expect the NPU to significantly outperform the CPU.

@swolchok (Contributor) commented Oct 3, 2025

> However, I would expect the NPU to significantly outperform the CPU.

For token generation, it depends on how the NPU accesses memory and thus on memory bandwidth, doesn't it? Wouldn't the expected gain be in prompt processing? I only see one set of performance numbers, so I assume the numbers given are for token generation rather than processing time for a longer prompt.

@kimishpatel (Contributor) commented:

> However, I would expect the NPU to significantly outperform the CPU.
>
> For token generation, it depends on how the NPU accesses memory and thus on memory bandwidth, doesn't it? Wouldn't the expected gain be in prompt processing? I only see one set of performance numbers, so I assume the numbers given are for token generation rather than processing time for a longer prompt.

Yes Scott, that's right. But besides the issue you mentioned about the prefill numbers not being reported, I also see that the NPU is slower than its CPU counterpart.
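
To make the bandwidth argument concrete, here is a rough back-of-envelope sketch (not from the PR) of the memory-bandwidth ceiling on decode throughput; the effective DRAM bandwidth figure is an assumption purely for illustration. Since the CPU and NPU on the same SoC stream weights from the same DRAM, decode speed does not automatically favor the NPU, whereas prompt processing is compute-bound and is where the NPU would be expected to help most.

```python
# Rough ceiling on autoregressive decode speed: each generated token streams
# (roughly) all weights once, so max tokens/s ~ memory bandwidth / weight bytes.
# The bandwidth value is an assumed, illustrative number, not a measurement.
params = 1.2e9                   # ~1.2B parameters (Llama-3.2-1B-Instruct)
assumed_bandwidth_gb_s = 100.0   # assumed effective DRAM bandwidth (illustrative)

for name, bytes_per_param in [("FP32", 4.0), ("INT4", 0.5)]:
    weight_gb = params * bytes_per_param / 1e9
    ceiling_tok_s = assumed_bandwidth_gb_s / weight_gb
    print(f"{name}: ~{weight_gb:.1f} GB of weights -> decode ceiling ~{ceiling_tok_s:.0f} tok/s")
```

Under these assumptions the FP32 ceiling sits close to the measured CPU FP32 numbers, while the INT4 ceiling is far above both the CPU and NPU INT4 measurements, which is consistent with something other than raw DRAM bandwidth limiting those runs.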

-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
-DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
-DEXECUTORCH_BUILD_OPENVINO_EXECUTOR_RUNNER=ON \

If we remove the need for separate runners it will also be easier, since you won't have to maintain your own build scripts. You should really only require installing the deps.

Yes, in openvino_build.sh we are just adding the normal executor now. This build script is just to make it easier for users to build the stack with OpenVINO instead of manually running the cmake commands; this is the feedback we got from other users as well when they tried out the OpenVINO build instructions.

@kimishpatel (Contributor) left a review comment:

Overall I have two major comments:

  1. Quantization: can we leverage the existing quantization API instead of doing something custom?
  2. Runner: why do we need a separate runner? Can we not use the existing LLM runner?

cc: @jackzhxng
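
For context, a minimal sketch of the existing PT2E quantization flow that comment 1 refers to (annotate with a backend quantizer, calibrate, convert). The quantizer is left as a parameter because the exact OpenVINO quantizer class used in this PR is not shown in this thread, and the export entry point varies a bit across PyTorch versions:

```python
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

def quantize_with_pt2e(model: torch.nn.Module, example_inputs: tuple, quantizer):
    # Export to an FX graph (entry point varies by PyTorch version),
    # annotate it with the backend-provided quantizer, run a calibration
    # pass on sample inputs, then convert to the quantized graph that the
    # backend partitioner consumes.
    exported = torch.export.export(model, example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)   # calibration pass
    return convert_pt2e(prepared)
```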

@kimishpatel (Contributor) commented:

@cavusmustafa please request review again once the comments are addressed

@cavusmustafa (Contributor, Author) commented:

> @cavusmustafa please request review again once the comments are addressed

@kimishpatel, thank you for taking the time to review and share your feedback! We’ll revise and update based on your suggestions and let you know once the updates are ready.

CC: @suryasidd

@jackzhxng (Contributor) left a review comment:

Everything under examples/ looks good to me

@cavusmustafa (Contributor, Author) commented:

> @cavusmustafa please request review again once the comments are addressed

@kimishpatel, we have updated the runner code, and it is ready for further review. Please let us know if any additional changes are needed. In the meantime, we are running further experiments with NPU and will share our findings soon. Thank you.

cc: @suryasidd

@suryasidd (Contributor) commented:

> @kimishpatel, we have updated the runner code, and it is ready for further review. Please let us know if any additional changes are needed. In the meantime, we are running further experiments with NPU and will share our findings soon. Thank you.

Added one more commit with a very minor change: switching the 4-bit quantization scheme to symmetric to extract more performance for Llama models on NPU, and setting the ratio to 1 to compress all the nodes.
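
For readers less familiar with the NNCF terminology used here, a minimal illustrative sketch of what a symmetric 4-bit scheme with ratio 1 (and the channel-wise variant from the performance table) looks like through NNCF weight compression; the exact parameters used in the commit are an assumption:

```python
import nncf

def compress_llama_weights(model, channel_wise: bool = True):
    # INT4_SYM selects the symmetric 4-bit scheme; ratio=1.0 applies it to all
    # eligible weight nodes rather than a mixed INT4/INT8 split; group_size=-1
    # means per-channel (channel-wise) quantization, otherwise grouped.
    return nncf.compress_weights(
        model,
        mode=nncf.CompressWeightsMode.INT4_SYM,
        ratio=1.0,
        group_size=-1 if channel_wise else 128,
    )
```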

@mergennachin (Contributor) commented:

@cavusmustafa, don't merge just yet; let me import and run some sanity checks on our internal unit tests.

@cavusmustafa (Contributor, Author) commented:

> @cavusmustafa, don't merge just yet; let me import and run some sanity checks on our internal unit tests.

Sure @kimishpatel, thank you for reviewing the PR. Let us know of any issues

@meta-codesync bot commented Oct 13, 2025

@mergennachin has imported this pull request. If you are a Meta employee, you can view this in D84544001.

@ynimmaga added the "partner: intel" label (for backend delegation, kernels, demos, etc. from the 3rd-party partner, Intel) on Oct 13, 2025.
The meta-codesync bot merged commit e731449 into pytorch:main on Oct 14, 2025 (258 of 315 checks passed).