
Commit 1ab4dfe

Update on "[ET-VK][qlinear] Faster weight only quantized linear gemv kernel"
## Changes

* Introduce a new compute shader for int4 linear's gemv cases that performs much better than the existing shader. This shader is inspired by MNN's gemv_1x1_conv_buf.cl shader. With this compute kernel, transformer models' text generation executes much faster than before. On a Samsung Galaxy S24 running Llama 3.2 1B and generating 128 tokens:
  * Before: ~25 tok/s
  * After: ~49 tok/s

## Why this new shader is faster

The biggest reason is the vectorized loading of the uint4 weight buffer. The new shader loads the weight buffer as a buffer/image of `uvec4`, whereas the old shader loads it as a buffer/image of `u8vec4`. Using the Adreno Offline Compiler, I found that the former uses only one load instruction to load from the weight tensor, whereas the latter uses 16 load instructions. It appears that the data loading was not being vectorized at the assembly level; this is potentially behaviour that can be improved in the SPIR-V shader compiler.

An additional factor is a better weight packing layout: the new prepacking routine results in better memory coalescing between threads in a work group.

The final major factor is the use of a tree-based reduction to co-operatively reduce partial results into the final output. Previously, a single thread was responsible for the final reduction.

## Future Work

* Introduce a faster shader for int4 linear gemm cases
* Update QCSNW to also use these updated shaders

Differential Revision: [D78275584](https://our.internmc.facebook.com/intern/diff/D78275584/)

[ghstack-poisoned]
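As a rough illustration only (not the shader added by this commit), the sketch below shows the two ideas described above in a generic Vulkan GLSL compute shader: a single `uvec4` load brings in 128 bits (32 packed 4-bit weights) per instruction, and a shared-memory tree reduction combines the per-thread partial sums instead of leaving the final reduction to one thread. The buffer names, bindings, unpacking, and work-group size are illustrative assumptions, not the actual ExecuTorch Vulkan kernel.

```glsl
#version 450
// Hedged sketch only -- NOT the shader added by this commit. It illustrates:
// (1) loading packed int4 weights as uvec4 so 32 nibbles arrive per load, and
// (2) a tree-based reduction so the final sum is produced co-operatively.
// All names, bindings, and the work-group size are illustrative assumptions.

layout(local_size_x = 64) in;

layout(std430, set = 0, binding = 0) readonly buffer Weights { uvec4 qweights[]; };
layout(std430, set = 0, binding = 1) readonly buffer Input   { vec4 x[]; };
layout(std430, set = 0, binding = 2) writeonly buffer Output { float y[]; };

shared float partial_sums[64];

void main() {
    const uint lid = gl_LocalInvocationID.x;

    // Each thread accumulates a partial dot product. One uvec4 load fetches
    // 32 packed 4-bit weights at once (vs. many narrower loads with u8vec4).
    float acc = 0.0;
    for (uint k = lid; k < uint(qweights.length()); k += gl_WorkGroupSize.x) {
        uvec4 wvec = qweights[k];                    // single vectorized load
        // A real kernel unpacks 8 nibbles per component and dequantizes with
        // per-group scales; abbreviated here to the low nibble of each component.
        acc += dot(vec4(wvec & uvec4(0xFu)), x[k]);  // illustrative indexing
    }
    partial_sums[lid] = acc;

    // Tree-based reduction: halve the number of active threads each step
    // instead of having thread 0 sum all 64 partial results serially.
    for (uint stride = gl_WorkGroupSize.x / 2u; stride > 0u; stride /= 2u) {
        memoryBarrierShared();
        barrier();
        if (lid < stride) {
            partial_sums[lid] += partial_sums[lid + stride];
        }
    }

    if (lid == 0u) {
        y[gl_WorkGroupID.x] = partial_sums[0];
    }
}
```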
2 parents 7f88379 + ac53ab0 commit 1ab4dfe


51 files changed (+2949 -581 lines changed)

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-7cda4017ddda554752e89069ae205be5e8388f59
+90f1e7bed15ca5e48c61c5b6dc5ad4810524f82f

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
@@ -197,7 +197,7 @@ jobs:
       docker-image: executorch-ubuntu-22.04-arm-sdk
       submodules: 'recursive'
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
-      timeout: 90
+      timeout: 120
       script: |
         # The generic Linux job chooses to use base env, not the one setup by the image
         CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@

 @functools.lru_cache
 def get_symmetric_quantization_config(
-    is_per_channel: bool = False,
+    is_per_channel: bool = True,
     is_qat: bool = False,
     is_dynamic: bool = False,
     act_qmin: int = -128,

backends/arm/test/common.py

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,7 @@
     arm_executor_runner_exists,
     corstone300_installed,
     corstone320_installed,
+    model_converter_installed,
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -245,6 +246,13 @@ def get_u85_compile_spec_unbuilt(
 )
 """Xfails a test if Corsone320 FVP is not installed, or if the executor runner is not built"""

+SkipIfNoModelConverter = pytest.mark.skipif(
+    condition=not (model_converter_installed()),
+    raises=FileNotFoundError,
+    reason="Did not find model-converter on path",
+)
+"""Xfails a test if model-converter is not installed"""
+
 xfail_type = str | tuple[str, type[Exception]]

backends/arm/test/misc/test_bn_relu_folding_qat.py

Lines changed: 3 additions & 1 deletion
@@ -59,7 +59,9 @@ def test_qat_tosa_BI(model: torch.nn.Module):
         "quantize",
         Quantize(
             quantizer=quantizer,
-            quantization_config=get_symmetric_quantization_config(is_qat=True),
+            quantization_config=get_symmetric_quantization_config(
+                is_qat=True, is_per_channel=False
+            ),
             is_qat=True,
         ),
     )

backends/arm/test/ops/test_add.py

Lines changed: 2 additions & 4 deletions
@@ -7,8 +7,6 @@

 from typing import Tuple

-import pytest
-
 import torch
 from executorch.backends.arm.arm_backend import get_tosa_spec
 from executorch.backends.arm.quantizer import arm_quantizer
@@ -190,7 +188,7 @@ def test_add_tensor_u85_BI_2(test_data: input_t2):


 @common.parametrize("test_data", Add.test_data)
-@pytest.mark.skip(reason="Model converter not yet made available")
+@common.SkipIfNoModelConverter
 def test_add_tensor_vgf_fp(test_data: input_t1):
     pipeline = VgfPipeline[input_t1](
         Add(), test_data(), aten_op, exir_op, tosa_version="TOSA-1.0+FP"
@@ -199,7 +197,7 @@ def test_add_tensor_vgf_fp(test_data: input_t1):


 @common.parametrize("test_data", Add.test_data)
-@pytest.mark.skip(reason="Model converter not yet made available")
+@common.SkipIfNoModelConverter
 def test_add_tensor_vgf_int(test_data: input_t1):
     pipeline = VgfPipeline[input_t1](
         Add(),

backends/arm/test/ops/test_multihead_attention.py

Lines changed: 12 additions & 1 deletion
@@ -53,7 +53,14 @@ def test_multihead_attention_tosa_MI(test_data: input_t1):
 )
 def test_multihead_attention_tosa_BI(test_data):
     test_data, module = test_data()
-    pipeline = TosaPipelineBI(module, (*test_data, *test_data, *test_data), [], [])
+    pipeline = TosaPipelineBI(
+        module,
+        (*test_data, *test_data, *test_data),
+        [],
+        [],
+        # TODO: Per-channel quantization is broken (MLETORCH-1144)
+        per_channel_quantization=False,
+    )
     pipeline.run()


@@ -72,6 +79,8 @@ def test_multihead_attention_u55_BI(test_data: input_t1):
         [],
         use_to_edge_transform_and_lower=True,
         run_on_fvp=True,
+        # TODO: Per-channel quantization is broken (MLETORCH-1144)
+        per_channel_quantization=False,
     )
     pipeline.pop_stage("check_count.exir")
     pipeline.run()
@@ -92,5 +101,7 @@ def test_multihead_attention_u85_BI(test_data: input_t1):
         [],
         use_to_edge_transform_and_lower=True,
         run_on_fvp=True,
+        # TODO: Per-channel quantization is broken (MLETORCH-1144)
+        per_channel_quantization=False,
     )
     pipeline.run()

backends/arm/test/runner_utils.py

Lines changed: 9 additions & 0 deletions
@@ -549,6 +549,15 @@ def corstone320_installed() -> bool:
     return True


+def model_converter_installed() -> bool:
+    cmd = ["model-converter", "--version"]
+    try:
+        _run_cmd(cmd, check=True)
+    except:
+        return False
+    return True
+
+
 def get_elf_path(target_board):
     elf_path = os.path.join(
         "arm_test",

backends/arm/test/test_arm_baremetal.sh

Lines changed: 3 additions & 1 deletion
@@ -210,7 +210,9 @@ test_models_ethos-u55() { # End to End model tests using model_test.py
     python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u55-64 --model=mv3 --extra_flags="-DET_ATOL=5.00 -DET_RTOL=5.00"
     python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u55-256 --model=lstm --extra_flags="-DET_ATOL=0.03 -DET_RTOL=0.03"
     python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u55-128 --model=resnet18 --extra_flags="-DET_ATOL=0.2 -DET_RTOL=0.2"
-    python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u55-128 --model=resnet50 --extra_flags="-DET_ATOL=0.2 -DET_RTOL=0.2"
+    # TODO: Output performance for resnet50 is bad with per-channel quantization (MLETORCH-1149).
+    # Also we get OOM when running this model. Disable it for now.
+    #python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u55-128 --model=resnet50 --extra_flags="-DET_ATOL=6.2 -DET_RTOL=6.2"

     echo "${TEST_SUITE_NAME}: PASS"
 }

backends/arm/test/tester/test_pipeline.py

Lines changed: 25 additions & 33 deletions
@@ -300,7 +300,7 @@ def __init__(
         run_on_tosa_ref_model: bool = True,
         tosa_version: str = "TOSA-0.80+BI",
         symmetric_io_quantization: bool = False,
-        per_channel_quantization: bool = False,
+        per_channel_quantization: bool = True,
         use_to_edge_transform_and_lower: bool = True,
         custom_path: str = None,
         atol: float = 1e-03,
@@ -317,16 +317,14 @@ def __init__(
         compile_spec = common.get_tosa_compile_spec(
             tosa_profiles[tosa_version], custom_path=custom_path
         )
-        if symmetric_io_quantization or per_channel_quantization:
-            quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
-            quantization_config = get_symmetric_quantization_config(
-                is_per_channel=per_channel_quantization
-            )
-            if symmetric_io_quantization:
-                quantizer.set_io(quantization_config)
-            quant_stage = Quantize(quantizer, quantization_config)
-        else:
-            quant_stage = None
+
+        quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+        quantization_config = get_symmetric_quantization_config(
+            is_per_channel=per_channel_quantization
+        )
+        if symmetric_io_quantization:
+            quantizer.set_io(quantization_config)
+        quant_stage = Quantize(quantizer, quantization_config)

         super().__init__(
             module,
@@ -475,24 +473,21 @@ def __init__(
         exir_ops: Optional[str | List[str]] = None,
         run_on_fvp: bool = True,
         symmetric_io_quantization: bool = False,
-        per_channel_quantization: bool = False,
+        per_channel_quantization: bool = True,
         use_to_edge_transform_and_lower: bool = True,
         custom_path: str = None,
         atol: float = 1e-03,
         rtol: float = 1e-03,
         qtol: int = 1,
     ):
         compile_spec = common.get_u55_compile_spec(custom_path=custom_path)
-        if symmetric_io_quantization or per_channel_quantization:
-            quantizer = EthosUQuantizer(compile_spec)
-            quantization_config = get_symmetric_quantization_config(
-                is_per_channel=per_channel_quantization
-            )
-            if symmetric_io_quantization:
-                quantizer.set_io(quantization_config)
-            quant_stage = Quantize(quantizer, quantization_config)
-        else:
-            quant_stage = None
+        quantizer = EthosUQuantizer(compile_spec)
+        quantization_config = get_symmetric_quantization_config(
+            is_per_channel=per_channel_quantization
+        )
+        if symmetric_io_quantization:
+            quantizer.set_io(quantization_config)
+        quant_stage = Quantize(quantizer, quantization_config)

         super().__init__(
             module,
@@ -565,24 +560,21 @@ def __init__(
         exir_ops: str | List[str] = None,
         run_on_fvp: bool = True,
         symmetric_io_quantization: bool = False,
-        per_channel_quantization: bool = False,
+        per_channel_quantization: bool = True,
         use_to_edge_transform_and_lower: bool = True,
         custom_path: str = None,
         atol: float = 1e-03,
         rtol: float = 1e-03,
         qtol: int = 1,
     ):
         compile_spec = common.get_u85_compile_spec(custom_path=custom_path)
-        if symmetric_io_quantization or per_channel_quantization:
-            quantizer = EthosUQuantizer(compile_spec)
-            quantization_config = get_symmetric_quantization_config(
-                is_per_channel=per_channel_quantization
-            )
-            if symmetric_io_quantization:
-                quantizer.set_io(quantization_config)
-            quant_stage = Quantize(quantizer, quantization_config)
-        else:
-            quant_stage = None
+        quantizer = EthosUQuantizer(compile_spec)
+        quantization_config = get_symmetric_quantization_config(
+            is_per_channel=per_channel_quantization
+        )
+        if symmetric_io_quantization:
+            quantizer.set_io(quantization_config)
+        quant_stage = Quantize(quantizer, quantization_config)

         super().__init__(
             module,
