
Commit a6b6383

usamahz, nikhil-arm, and malfet authored and committed
[ARM] Improve LLM performance & mem usage using int4-bf16 KleidiAI kernels (pytorch#158250)
Co-authored-by: Nikhil Gupta [[email protected]](mailto:[email protected])

This PR enables the use of KleidiAI INT4 kernels that directly produce BF16 outputs within PyTorch, boosting LLM prefill & decode performance.

**This change improves decode throughput by ~15% and reduces the memory required for model inference by 50%.**

### Benchmark Setup

```
Model: meta-llama/Llama-3.1-8B
Test Platform: Neoverse V2
```

### Detailed Results

| Metric | With `--compile` | Without `--compile` |
|----------------------------------|----------------------------|----------------------------|
| Quantization Scheme | INT4 symmetric channelwise | INT4 symmetric channelwise |
| Input Precision | BF16 | BF16 |
| Number of Layers Quantized | 32 | 32 |
| Average Compression Ratio | 87.49% | 87.49% |
| Total Quantization Time (s) | 9.62 | 10.32 |
| Compile Time (First) (s) | 134.48 | 1.69 |
| Compile Time (Second) (s) | 80.44 | 1.60 |
| Compile Time (Subsequent) (s) | 0.19 | 0.22 |
| Prefill Tokens | 54 | 54 |
| Decoded Tokens | 33 | 33 |
| Prefill Time (s) | 0.19 | 0.22 |
| Decode Time (s) | 0.76 | 1.38 |
| E2E Generation Time (s) | 0.95 | 1.60 |
| Prefill Throughput (tokens/s) | 288.13 | 249.91 |
| Decode Throughput (tokens/s) | 43.42 | 23.83 |

Pull Request resolved: pytorch#158250
Approved by: https://github.com/malfet, https://github.com/aditew01, https://github.com/fadara01

Co-authored-by: Nikhil Gupta <[email protected]>
Co-authored-by: Nikita Shulga <[email protected]>
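For context, here is a minimal Python sketch of how the two aten ops touched by this work are typically driven, with a BF16 activation now permitted in the channelwise case. The tensor shapes and the scale-tensor layout below are illustrative assumptions, not the documented packing contract of `_dyn_quant_pack_4bit_weight`:

```python
import torch

# Channelwise INT4 quantization: one quantization group per output row,
# so block_size == in_features (the case BF16 inputs are now allowed for).
in_features, out_features = 4096, 4096
block_size = in_features

# Two int4 values packed per uint8 byte, plus per-channel fp32 scales
# (symmetric scheme, so no zero points). NOTE: these shapes and the
# scale-tensor layout are assumptions for illustration; the real layout
# follows _dyn_quant_pack_4bit_weight's packing contract.
weight_int4 = torch.randint(0, 16, (out_features, in_features // 2), dtype=torch.uint8)
scales = torch.rand(out_features, dtype=torch.float32)

packed = torch.ops.aten._dyn_quant_pack_4bit_weight(
    weight_int4, scales, None, block_size, in_features, out_features
)

# With this change, a bfloat16 activation is accepted and the KleidiAI
# int4-bf16 kernel produces bfloat16 output directly, avoiding a
# float32 round-trip.
x = torch.randn(1, in_features, dtype=torch.bfloat16)
y = torch.ops.aten._dyn_quant_matmul_4bit(x, packed, block_size, in_features, out_features)
assert y.dtype == torch.bfloat16
```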
1 parent 21c11da · commit a6b6383

File tree: 9 files changed (+664, −137 lines)


aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -3554,9 +3554,9 @@ Tensor _dyn_quant_matmul_4bit_cpu(
     const int64_t out_features) {
   auto M = inp.size(0);
   TORCH_CHECK(
-      inp.dtype() == kFloat,
+      inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features),
       __func__,
-      " : expect input to be 32-bit float tensor.");
+      " : expect input to be float32 or bfloat16 tensor.");
   TORCH_CHECK(
       block_size == in_features ||
           (!(block_size % 32) && !(in_features % block_size)),
```
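Read as a predicate, the updated check admits float32 unconditionally and bfloat16 only when quantization is channelwise (`block_size == in_features`). A plain-Python restatement for illustration (the helper name is hypothetical):

```python
import torch

def input_dtype_ok(dtype: torch.dtype, block_size: int, in_features: int) -> bool:
    # Mirrors the updated TORCH_CHECK: float32 is always accepted; bfloat16
    # only when quantization is channelwise (block_size == in_features).
    return dtype == torch.float32 or (
        dtype == torch.bfloat16 and block_size == in_features
    )

assert input_dtype_ok(torch.float32, 32, 4096)        # fp32 + blockwise: OK
assert input_dtype_ok(torch.bfloat16, 4096, 4096)     # bf16 + channelwise: OK
assert not input_dtype_ok(torch.bfloat16, 32, 4096)   # bf16 + blockwise: rejected
```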
