Commit d67426f

Author: ssjia
Update on "[ET-VK] Add kInt8x4 dtype and GPUMemoryLayouts for packed quantized tensors"
## Motivation

Lay the foundations for being able to execute statically quantized CNNs with ET-VK. Unlike with dynamic quantization, static quantization allows the output of quantized operators to stay in integer representation and be fed directly to the next quantized operator.

## Context

Typically, int8 quantized tensors can be represented by simply having the tensor use the int8 data type. While this is possible in ET-VK, in practice quantized operators expect int8 quantized tensors to be packed so that 16 8-bit values are stored in each `ivec4`, meaning quantized int8 tensors load/store with a granularity of 16 elements. The reason for this is twofold:

* Support for the shader int8 / storage buffer int8 extensions is not guaranteed, meaning some devices do not allow using int8 types in shaders.
* We have found that load/store from storage buffers/textures that use int8 data types sometimes results in worse memory load performance, due to vectorized load/store instructions not being used.

Therefore, in ET-VK we need a way to mark that a quantized tensor should:

1. Use int32 as the underlying data type for the storage buffer/texture
2. Account for the block-packing that may be used

## Changes

First, introduce the `Int8x4` dtype that can be used for packed int8 tensors. This dtype is functionally the same as `Int`, but denotes that each int32 actually contains 4 packed 8-bit values.

Second, introduce new memory layouts: `kPackedInt8_4W4C` and `kPackedInt8_4H4W`. The former will be used for convolution, while the latter will be used for matrix multiplication. See the inline comments for more details about these memory layouts.

Then, update `QuantizedConvolution.cpp` and `QuantizedLinear.cpp` to use the new data type and memory layouts for the packed int8 input tensor.

Differential Revision: [D82542336](https://our.internmc.facebook.com/intern/diff/D82542336/)

[ghstack-poisoned]
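To illustrate the 4-per-int32 packing idea, here is a minimal NumPy sketch; the helper names are hypothetical, and the actual block ordering is defined by the memory layouts and their inline comments rather than by this snippet:

```python
import numpy as np

def pack_int8x4(values: np.ndarray) -> np.ndarray:
    """Pack a flat int8 array into int32s, four 8-bit values per int32."""
    assert values.dtype == np.int8 and values.size % 4 == 0
    # Reinterpreting each group of 4 bytes as one 32-bit integer mirrors how an
    # Int8x4 tensor stores its data in an int32 storage buffer/texture.
    return values.reshape(-1, 4).view(np.int32).reshape(-1)

def unpack_int8x4(packed: np.ndarray) -> np.ndarray:
    """Recover the original int8 values from the packed int32 buffer."""
    return packed.view(np.int8)

q = np.arange(-8, 8, dtype=np.int8)   # 16 int8 values = one ivec4 worth of packed data
packed = pack_int8x4(q)               # 4 int32s
assert np.array_equal(unpack_int8x4(packed), q)
```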
2 parents: ff401cb + 1d26e94

File tree

46 files changed: +794, -659 lines changed


.ci/scripts/test_backend_linux.sh

Lines changed: 6 additions & 1 deletion

@@ -39,12 +39,17 @@ if [[ "$FLOW" == *qnn* ]]; then
 fi
 
 if [[ "$FLOW" == *vulkan* ]]; then
-  # Setup swiftshader and Vulkan SDK which are required to build the Vulkan delegate
+  # Setup swiftshader and Vulkan SDK which are required to build the Vulkan delegate.
   source .ci/scripts/setup-vulkan-linux-deps.sh
 
   EXTRA_BUILD_ARGS+=" -DEXECUTORCH_BUILD_VULKAN=ON"
 fi
 
+if [[ "$FLOW" == *arm* ]]; then
+  # Setup ARM deps.
+  .ci/scripts/setup-arm-baremetal-tools.sh
+fi
+
 # We need the runner to test the built library.
 PYTHON_EXECUTABLE=python CMAKE_ARGS="$EXTRA_BUILD_ARGS" .ci/scripts/setup-linux.sh --build-tool cmake --build-mode Release --editable true

.ci/scripts/test_llava.sh

Lines changed: 1 addition & 1 deletion

@@ -149,7 +149,7 @@ run_and_verify() {
 
   # verify result.txt
   RESULT=$(cat result.txt)
-  EXPECTED_PREFIX="ASSISTANT: image captures a basketball game in progress, with"
+  EXPECTED_PREFIX="ASSISTANT: The image captures a basketball game in progress, with"
 
   if [[ "${RESULT}" == *"${EXPECTED_PREFIX}"* ]]; then
    echo "Expected result prefix: ${EXPECTED_PREFIX}"

.github/workflows/_link_check.yml

Lines changed: 26 additions & 0 deletions

@@ -55,3 +55,29 @@ jobs:
           echo "Or add \`@lint-ignore\` somewhere on the same line as the reference you want to skip checking."
           exit 1
         }
+
+  lint-file-size:
+    if: ${{ github.event_name == 'pull_request' }}
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    with:
+      runner: linux.2xlarge
+      docker-image: ci-image:executorch-ubuntu-22.04-linter
+      submodules: false
+      fetch-depth: 0
+      ref: ${{ inputs.ref }}
+      timeout: 30
+      script: |
+        chmod +x ./scripts/lint_file_size.sh
+        ./scripts/lint_file_size.sh $(
+          if [ "${{ github.event_name }}" = "pull_request" ]; then
+            echo "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"
+          else
+            echo "${{ github.event.before }}" "${{ github.sha }}"
+          fi
+        ) || {
+          echo
+          echo "File size lint failed: some files exceed the 1 MB limit."
+          echo "If you really need large files, consider using Git LFS or storing them elsewhere."
+          echo "If you really need to get unblocked and check in the file, can add it to the EXCEPTIONS list in scripts/lint_file_size.sh."
+          exit 1
+        }

.github/workflows/_test_backend.yml

Lines changed: 6 additions & 1 deletion

@@ -31,6 +31,11 @@ on:
         required: false
         type: boolean
         default: false
+      runner-linux:
+        description: 'Runner type for Linux jobs'
+        required: false
+        type: string
+        default: linux.4xlarge.memory
 
 jobs:
   test-backend-linux:
@@ -44,7 +49,7 @@ jobs:
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
       ref: ${{ inputs.ref }}
-      runner: linux.4xlarge.memory
+      runner: ${{ inputs.runner-linux }}
       docker-image: ci-image:executorch-ubuntu-22.04-clang12
       submodules: recursive
       timeout: ${{ inputs.timeout }}

.github/workflows/test-backend-arm.yml

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+name: Test ARM Backend
+
+on:
+  schedule:
+    - cron: 0 2 * * *
+  push:
+    tags:
+      - ciflow/nightly/*
+  pull_request:
+    paths:
+      - .github/workflows/test-backend-arm.yml
+      - .github/workflows/_test_backend.yml
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}--${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
+  cancel-in-progress: true
+
+jobs:
+  test-arm:
+    uses: ./.github/workflows/_test_backend.yml
+    with:
+      backend: arm
+      flows: '["arm_tosa"]'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      timeout: 120
+      run-linux: true

.github/workflows/test-backend-qnn.yml

Lines changed: 1 addition & 0 deletions

@@ -25,3 +25,4 @@ jobs:
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       timeout: 120
       run-linux: true
+      runner-linux: linux.8xlarge.memory

backends/arm/test/tester/arm_tester.py

Lines changed: 14 additions & 2 deletions

@@ -57,6 +57,7 @@
 
 from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
 
+from executorch.backends.test.harness.error_statistics import ErrorStatistics
 from executorch.backends.test.harness.stages import Stage, StageType
 from executorch.backends.xnnpack.test.tester import Tester
 from executorch.devtools.backend_debug import get_delegation_info
@@ -333,6 +334,7 @@ def to_edge_transform_and_lower(
         transform_passes: Optional[
             Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
         ] = None,
+        generate_etrecord: bool = False,
     ):
         if transform_passes is not None:
             raise RuntimeError(
@@ -367,7 +369,9 @@ def to_edge_transform_and_lower(
         to_edge_and_lower_stage.partitioners = partitioners
         if edge_compile_config is not None:
             to_edge_and_lower_stage.edge_compile_conf = edge_compile_config
-        return super().to_edge_transform_and_lower(to_edge_and_lower_stage)
+        return super().to_edge_transform_and_lower(
+            to_edge_and_lower_stage, generate_etrecord=generate_etrecord
+        )
 
     def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] | None = None):
         if to_executorch_stage is None:
@@ -402,6 +406,7 @@ def run_method_and_compare_outputs(
         qtol=0,
         error_callbacks=None,
         run_eager_mode=False,
+        statistics_callback: Callable[[ErrorStatistics], None] | None = None,
    ):
        """
        Compares the run_artifact output of 'stage' with the output of a reference stage.
@@ -657,10 +662,17 @@ def _compare_outputs(
         rtol=1e-03,
         qtol=0,
         error_callbacks=None,
+        statistics_callback: Callable[[ErrorStatistics], None] | None = None,
     ):
         try:
             super()._compare_outputs(
-                reference_output, stage_output, quantization_scale, atol, rtol, qtol
+                reference_output,
+                stage_output,
+                quantization_scale,
+                atol,
+                rtol,
+                qtol,
+                statistics_callback=statistics_callback,
             )
         except AssertionError as e:
             if error_callbacks is None:

backends/cadence/aot/TARGETS

Lines changed: 17 additions & 1 deletion

@@ -143,7 +143,23 @@ executorch_generated_lib(
     visibility = ["PUBLIC"],
     deps = [
         "//executorch/backends/cadence/generic/kernels:cadence_kernels",
-        "//executorch/backends/cadence/generic/operators:cadence_generic_ops",
+        # Individual operator targets instead of combined cadence_generic_ops
+        "//executorch/backends/cadence/generic/operators:op_add",
+        "//executorch/backends/cadence/generic/operators:op_embedding",
+        "//executorch/backends/cadence/generic/operators:op_full",
+        "//executorch/backends/cadence/generic/operators:op_requantize_out",
+        "//executorch/backends/cadence/generic/operators:op_view_copy",
+        "//executorch/backends/cadence/generic/operators:im2row_out",
+        "//executorch/backends/cadence/generic/operators:dequantize_per_tensor",
+        "//executorch/backends/cadence/generic/operators:quantize_per_tensor",
+        "//executorch/backends/cadence/generic/operators:quantized_add_out",
+        "//executorch/backends/cadence/generic/operators:quantized_conv_nchw_out",
+        "//executorch/backends/cadence/generic/operators:quantized_conv_nhwc_out",
+        "//executorch/backends/cadence/generic/operators:quantized_fully_connected_out",
+        "//executorch/backends/cadence/generic/operators:quantized_layer_norm",
+        "//executorch/backends/cadence/generic/operators:quantized_linear_out",
+        "//executorch/backends/cadence/generic/operators:quantized_matmul_out",
+        "//executorch/backends/cadence/generic/operators:quantized_relu_out",
         "//executorch/kernels/portable:executorch_all_ops",
         "//executorch/kernels/portable:operators",
     ],

backends/cadence/aot/ops_registrations.py

Lines changed: 39 additions & 0 deletions

@@ -324,6 +324,19 @@
     "rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
+)
+lib.define(
+    "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
+)
+lib.define(
+    "quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+)
+lib.define(
+    "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+)
+
 # Load/store with iDMA. These only exist before memory planning.
 # Post memory planning, we check that outputs/inputs for the load/store are in
 # DTCM and replace idma_load/idma_store with idma_copy.
@@ -2329,3 +2342,29 @@ def softmax_f32_f32_meta(
     half_to_float: Optional[bool] = None,
 ) -> torch.Tensor:
     return self.new_empty(self.size(), dtype=self.dtype)
+
+
+@register_fake("cadence::quantized_softmax")
+def quantized_softmax_meta(
+    input: torch.Tensor,
+    mask: torch.Tensor,
+    dim: int,
+    in_scale: torch.Tensor,
+    in_zero_point: torch.Tensor,
+    out_scale: torch.Tensor,
+    out_zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_softmax.per_tensor")
+def quantized_softmax_per_tensor_meta(
+    input: torch.Tensor,
+    mask: torch.Tensor,
+    dim: int,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
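
With these registrations in place, the new schema can be exercised for shape propagation without a device kernel. The following is a hypothetical sketch, not part of this diff: the module path, shapes, and scale/zero-point values are illustrative, and it assumes the op library above has been imported.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Importing the registrations module defines cadence::quantized_softmax.* and
# registers the fake (meta) kernels shown in the diff above.
import executorch.backends.cadence.aot.ops_registrations  # noqa: F401

with FakeTensorMode():
    x = torch.empty(1, 8, 128, dtype=torch.int8)
    mask = torch.empty(1, 8, 8, dtype=torch.int32)  # last dim = 128 // 16
    out = torch.ops.cadence.quantized_softmax.per_tensor(
        x, mask, -1, 0.05, 0, 0.05, 0
    )
    assert out.shape == x.shape and out.dtype == x.dtype
```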

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 78 additions & 1 deletion

@@ -6,9 +6,10 @@
 
 # pyre-strict
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, cast, Dict, List, Tuple
 
 import torch
+from executorch.backends.cadence.aot.compiler_utils import get_shape
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
     AddPattern,
@@ -25,6 +26,7 @@
     MatmulPattern,
     ReluPattern0,
     ReluPattern1,
+    SoftmaxPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     check_out_zero_point_is_min_range,
@@ -388,6 +390,73 @@ def get_args_and_kwargs_relu(
     return args, kwargs
 
 
+def get_args_and_kwargs_softmax(
+    graph_module: GraphModule,
+    inputs_inputs: List[fx.Node],
+    dequants_inputs: List[fx.Node],
+    quant_node: fx.Node,
+    op_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    # Make a dummy mask tensor
+    mask_shape = get_shape(graph_module, cast(fx.Node, quant_node.args[0]))
+    mask_shape = list(mask_shape) if mask_shape else []
+    mask_shape[-1] = mask_shape[-1] // 16
+    mask_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            mask_shape,
+            0.0,
+        ),
+        {"dtype": torch.int32},
+    )
+    # Make the scale and zero_point tensors
+    in_scale_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            dequants_inputs[0].args[1],
+        ),
+        {"dtype": torch.float32},
+    )
+    in_zero_point_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            dequants_inputs[0].args[2],
+        ),
+        {"dtype": torch.int32},
+    )
+    out_scale_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            quant_node.args[1],
+        ),
+        {"dtype": torch.float32},
+    )
+    out_zero_point_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            quant_node.args[2],
+        ),
+        {"dtype": torch.int32},
+    )
+
+    # Make the args and kwargs for the replacement op
+    args = (
+        inputs_inputs[0],
+        mask_tensor,
+        op_node.args[1],
+        in_scale_tensor,
+        in_zero_point_tensor,
+        out_scale_tensor,
+        out_zero_point_tensor,
+    )
+    kwargs = {}
+    return args, kwargs
+
+
 class QuantFusion(ExportPass):
     # pyre-ignore[2]: Parameter `patterns` has no type specified
     def __init__(self, patterns) -> None:
@@ -543,6 +612,14 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         dequants_inputs,
                         quant_node,
                     )
+                elif isinstance(pattern, SoftmaxPattern):
+                    args, kwargs = get_args_and_kwargs_softmax(
+                        graph_module,
+                        inputs_inputs,
+                        dequants_inputs,
+                        quant_node,
+                        anchor_output_node,
+                    )
                 fused = graph_module.graph.call_function(
                     pattern.replacement_op(),
                     args,
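
For reference, the scale and zero-point values pulled from the dequant/quant nodes above follow the standard per-tensor affine quantization convention. A minimal sketch of that mapping (not part of this diff; the int8 range and rounding behavior are assumptions):

```python
import torch

def quantize_per_tensor(x: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # q = clamp(round(x / scale) + zero_point, qmin, qmax), assuming an int8 range here.
    q = torch.round(x / scale) + zero_point
    return q.clamp(-128, 127).to(torch.int8)

def dequantize_per_tensor(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # Inverse affine mapping back to floating point.
    return (q.to(torch.float32) - zero_point) * scale

x = torch.randn(2, 8).clamp(-6.0, 6.0)
q = quantize_per_tensor(x, scale=0.05, zero_point=0)
x_hat = dequantize_per_tensor(q, scale=0.05, zero_point=0)
assert torch.allclose(x, x_hat, atol=0.05)  # round-trip error bounded by scale / 2
```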

0 commit comments
