
Commit 9308859

Author: morelos
Update base for Update on "[ET-VK][Ops] linear_qta8a_qga4w_qta8o impl and shaders"
# Operator Description

The linear_qta8a_qga4w_qta8o operator implements a quantized linear transformation that enables efficient neural network inference through dynamic quantization. It performs matrix multiplication between quantized 8-bit activations and 4-bit group-quantized weights, producing quantized 8-bit outputs.

The quantization scheme follows the standard affine mapping `real_value = scale * (quantized_value - zero_point)`. Input activations use 8-bit signed integers with per-token scale and zero-point parameters, while weights use 4-bit quantization with group-wise parameters (a hedged dequantization sketch follows the commit message below).

# Implementation Architecture

The operator provides two distinct computational approaches optimized for different matrix multiplication scenarios: the TILED algorithm for general matrix-matrix multiplication (GEMM) and the COOPERATIVE algorithm for matrix-vector multiplication (GEMV).

## TILED Algorithm (GEMM Cases)

The tiled implementation processes the output matrix in rectangular blocks. Each thread computes a tile of output values, typically 3 rows by 2 columns of results per iteration. Each thread loads blocks of quantized weights and activations, performs integer arithmetic accumulation, and then applies the necessary scaling operations.

Weight data is pre-packed in a specialized format where two 4-bit values are stored in each byte. Each thread loads multiple weight elements at once and unpacks them during computation. The quantization parameters for weights are organized by groups, where each group of consecutive weight elements shares the same scale and zero-point values.

## COOPERATIVE Algorithm (GEMV Cases)

The cooperative implementation relies on shared memory and thread cooperation: workgroups of 64 threads are arranged as 8 groups of 8 workers each. The key insight is that GEMV operations have limited parallelism in the output dimension but substantial parallelism in the reduction dimension, making cooperative reduction strategies more effective than independent per-thread computation.

Each group of 8 worker threads collaboratively computes a portion of the output vector. The workers divide the reduction work along the input feature dimension, with each worker processing every 8th element in a strided pattern (a sketch of this reduction pattern also follows below).

# Future Performance Improvements

- Making use of dotPacked4x8EXT (this requires upgrading glslc and Vulkan)
- Fixed-point math for pure integer operations
- It might be more performant to avoid preloading tensors
- It might also be more performant to avoid excessive register overhead by defining the ivec4 within each block operation (allowing more threads to be more register intensive)

Differential Revision: [D77173441](https://our.internmc.facebook.com/intern/diff/D77173441/)

[ghstack-poisoned]
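As a rough illustration of the quantization scheme described above, the C++ sketch below shows the affine mapping for the int8 activations, the unpacking of two 4-bit weight values per byte, and the requantization of a float result back to int8. This is not the shader code from this commit; the function names, the low-nibble-first packing order, and the round-to-nearest requantization are assumptions made for the example.

```cpp
#include <cmath>
#include <cstdint>

// Hypothetical helpers mirroring the affine scheme described in the commit
// message: real_value = scale * (quantized_value - zero_point).

// Dequantize one int8 activation with its per-token scale and zero-point.
inline float dequant_activation(int8_t q, float scale, int32_t zero_point) {
  return scale * static_cast<float>(static_cast<int32_t>(q) - zero_point);
}

// Two 4-bit weight values are packed into each byte; low-nibble-first order
// is an assumption here, not necessarily what the shader uses.
inline void unpack_4bit_pair(uint8_t packed, int32_t& w0, int32_t& w1) {
  w0 = static_cast<int32_t>(packed & 0x0F);
  w1 = static_cast<int32_t>(packed >> 4);
}

// Dequantize a 4-bit weight with the scale/zero-point of its group.
inline float dequant_weight(int32_t w, float group_scale, int32_t group_zp) {
  return group_scale * static_cast<float>(w - group_zp);
}

// Requantize an accumulated float value back to the int8 output domain.
inline int8_t requant_output(float v, float out_scale, int32_t out_zp) {
  int32_t q = static_cast<int32_t>(std::lround(v / out_scale)) + out_zp;
  q = q < -128 ? -128 : (q > 127 ? 127 : q);
  return static_cast<int8_t>(q);
}
```

A reference implementation could dequantize activations and weights with these helpers, accumulate the dot product, and requantize per output token; the shaders described above instead accumulate in integer arithmetic and apply the scaling afterward.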
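To make the cooperative GEMV strategy more concrete, here is a hedged CPU-side sketch of the strided reduction described above, with the 8 workers of one group modeled as a loop. The function name, the float element types, and the serial combination of partial sums (standing in for the shared-memory reduction) are all assumptions, not code from this commit.

```cpp
#include <cstddef>
#include <vector>

// Sketch of one group of 8 workers cooperatively reducing along the input
// feature dimension K for a single output element of a GEMV.
float cooperative_dot(const std::vector<float>& activation_row,
                      const std::vector<float>& weight_column,
                      std::size_t K) {
  constexpr std::size_t kWorkersPerGroup = 8;
  float partial[kWorkersPerGroup] = {};

  // Each worker handles every 8th element of the reduction dimension.
  for (std::size_t worker = 0; worker < kWorkersPerGroup; ++worker) {
    for (std::size_t k = worker; k < K; k += kWorkersPerGroup) {
      partial[worker] += activation_row[k] * weight_column[k];
    }
  }

  // In the shader the partial sums would be combined through shared memory;
  // a plain serial sum stands in for that step here.
  float sum = 0.0f;
  for (std::size_t worker = 0; worker < kWorkersPerGroup; ++worker) {
    sum += partial[worker];
  }
  return sum;
}
```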
2 parents 8a8baa8 + dd06b3b commit 9308859

File tree

77 files changed (+1210, -282 lines)

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-7cda4017ddda554752e89069ae205be5e8388f59
+9b498d3bb28b8e3411ce464dd2755c5b96d92c8f

.ci/scripts/check_c10_sync.sh

Lines changed: 1 addition & 1 deletion
@@ -12,4 +12,4 @@ pushd pytorch
 git checkout "$pytorch_pin"
 popd
 "$(dirname "${BASH_SOURCE[0]}")"/compare_dirs.sh runtime/core/portable_type/c10/c10 pytorch/c10
-"$(dirname "${BASH_SOURCE[0]}")"/compare_dirs.sh runtime/core/portable_type/c10/torch/headeronly pytorch/torch/headeronly
+"$(dirname "${BASH_SOURCE[0]}")"/compare_dirs.sh runtime/core/portable_type/c10/torch/standalone pytorch/torch/standalone

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-unsafe
+
+import executorch.backends.arm.tosa.dialect  # noqa: unused
 from executorch.backends.arm._passes import (
     AddBiasPass,
     AnnotateChannelsLastDimOrder,

backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,7 @@
 from typing import Tuple

 import torch
-from executorch.backends.arm.test import common
+from executorch.backends.arm.test import common, conftest

 from executorch.backends.arm.test.tester.test_pipeline import (
     OpNotSupportedPipeline,
@@ -55,6 +55,7 @@ def test_decorate_fp32_to_int32_casting_tosa_MI(test_data: Tuple):
         (test_tensor,),
         aten_op=[],
         exir_op=[],
+        run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"),
     )
     pipeline.run()


backends/cadence/fusion_g3/operators/op_exp.cpp

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
     return out;
   } else {
     return torch::executor::native::internal::
-        unary_ufunc_realhbbf16_to_floathbf16(std::exp, ctx, in, out);
+        unary_ufunc_realhbbf16_to_floathbf16(std::exp, std::exp, ctx, in, out);
   }
 }


backends/cadence/fusion_g3/operators/op_rsqrt.cpp

Lines changed: 4 additions & 3 deletions
@@ -27,7 +27,8 @@ namespace native {

 namespace {

-double rsqrt(double x) {
+template <typename T>
+T rsqrt(T x) {
   return 1.0 / std::sqrt(x);
 }

@@ -61,11 +62,11 @@ Tensor& rsqrt_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
     return out;
   } else {
     return torch::executor::native::internal::
-        unary_ufunc_realhbbf16_to_floathbf16(rsqrt, ctx, in, out);
+        unary_ufunc_realhbbf16_to_floathbf16(rsqrt, rsqrt, ctx, in, out);
   }
 }

 } // namespace native
 } // namespace G3
 } // namespace impl
-} // namespace cadence
+} // namespace cadence

backends/cadence/fusion_g3/operators/op_sqrt.cpp

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,8 @@ Tensor& sqrt_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
     return out;
   } else {
     return torch::executor::native::internal::
-        unary_ufunc_realhbbf16_to_floathbf16(std::sqrt, ctx, in, out);
+        unary_ufunc_realhbbf16_to_floathbf16(
+            std::sqrt, std::sqrt, ctx, in, out);
   }
 }


backends/cadence/fusion_g3/operators/op_tanh.cpp

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,8 @@ Tensor& tanh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
     return out;
   } else {
     return torch::executor::native::internal::
-        unary_ufunc_realhbbf16_to_floathbf16(std::tanh, ctx, in, out);
+        unary_ufunc_realhbbf16_to_floathbf16(
+            std::tanh, std::tanh, ctx, in, out);
   }
 }


backends/cadence/hifi/operators/op_rsqrt.cpp

Lines changed: 3 additions & 2 deletions
@@ -21,7 +21,8 @@ namespace HiFi {
 namespace native {
 namespace {

-double rsqrt(double x) {
+template <typename T>
+T rsqrt(T x) {
   return 1.0 / std::sqrt(x);
 }

@@ -46,7 +47,7 @@ Tensor& rsqrt_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
   }

   return torch::executor::native::internal::
-      unary_ufunc_realhbbf16_to_floathbf16(rsqrt, ctx, in, out);
+      unary_ufunc_realhbbf16_to_floathbf16(rsqrt, rsqrt, ctx, in, out);
 }

 } // namespace native

backends/cadence/hifi/operators/op_tanh.cpp

Lines changed: 2 additions & 2 deletions
@@ -35,10 +35,10 @@ Tensor& tanh_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
   }

   return torch::executor::native::internal::
-      unary_ufunc_realhbbf16_to_floathbf16(std::tanh, ctx, in, out);
+      unary_ufunc_realhbbf16_to_floathbf16(std::tanh, std::tanh, ctx, in, out);
 }

 } // namespace native
 } // namespace HiFi
 } // namespace impl
-} // namespace cadence
+} // namespace cadence
