
Commit c2e3ba5

Merge branch 'main' into dev1/winskuo/custom_annotation_fix
2 parents: baa8c8f + be221c6

File tree

  .github/workflows/trunk.yml
  backends/qualcomm/quantizer/custom_annotation.py
  backends/qualcomm/quantizer/qconfig.py
  backends/qualcomm/quantizer/quantizer.py

4 files changed: 174 additions, 40 deletions

.github/workflows/trunk.yml

Lines changed: 29 additions & 9 deletions
@@ -60,7 +60,7 @@ jobs:
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
     strategy:
       matrix:
-        model: [add]
+        model: [add, softmax, mv2]
       fail-fast: false
     with:
       runner: linux.2xlarge
@@ -72,6 +72,16 @@ jobs:
       MODEL_NAME=${{ matrix.model }}
       CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
       conda activate "${CONDA_ENV}"
+      if [[ ${{ matrix.model}} == "add" ]]; then
+        SIM_LIMIT_SEC=60
+      elif [[ ${{ matrix.model}} == "softmax" ]]; then
+        SIM_LIMIT_SEC=60
+      elif [[ ${{ matrix.model}} == "mv2" ]]; then
+        SIM_LIMIT_SEC=5000
+      else
+        echo "Failed unsupported model selection ${{ matrix.model }}"
+        exit 1
+      fi

       source .ci/scripts/utils.sh
       source .ci/scripts/zephyr-utils.sh
@@ -80,20 +90,23 @@ jobs:
       export ZEPHYR_PROJ_ROOT=$(realpath $(pwd))
       export ARM_FVP_TUTORIALS_ROOT=$ZEPHYR_PROJ_ROOT/zephyr/samples/modules/executorch/arm-fvp-tutorials

+      # TODO @Bujji: Should see if this can be moved into the docker image itself
       download_arm_zephyr_sdk
       ./zephyr-sdk-0.16.0/setup.sh -c -t arm-zephyr-eabi
-
       cd $ZEPHYR_PROJ_ROOT
       setup_zephyr_et_module

+      # Run setup scripts for Arm FVP and Arm AOT Compilation
       cd $ZEPHYR_PROJ_ROOT/modules/lib/executorch
       install_executorch "--use-pt-pinned-commit"
       .ci/scripts/setup-arm-baremetal-tools.sh --target-toolchain zephyr
       source examples/arm/ethos-u-scratch/setup_path.sh
       source $ZEPHYR_PROJ_ROOT/zephyr/zephyr-env.sh

       # Get the model as PTE
-      python -m examples.arm.aot_arm_compiler --model_name="${MODEL_NAME}" --output="${MODEL_NAME}.pte"
+      python -m examples.arm.aot_arm_compiler \
+        --model_name="${MODEL_NAME}" \
+        --output="${MODEL_NAME}.pte"

       # Generate the C-style header
       cd $ARM_FVP_TUTORIALS_ROOT
@@ -105,7 +118,8 @@ jobs:
       cd $ARM_FVP_TUTORIALS_ROOT/models/${MODEL_NAME}/

       # Build the zephyr elf
-      west build -p always -b mps3/corstone300/fvp
+      west build -p always -b mps3/corstone300/fvp -- \
+        -DET_PTE_FILE_PATH_FOR_SELECTIVE_BUILD=$ZEPHYR_PROJ_ROOT/modules/lib/executorch/${MODEL_NAME}.pte

       # Run the simulation
       FVP_Corstone_SSE-300_Ethos-U55 -a build/zephyr/zephyr.elf \
@@ -114,23 +128,29 @@ jobs:
         -C mps3_board.uart0.out_file='sim.out' \
         -C cpu0.CFGITCMSZ=15 \
         -C cpu0.CFGDTCMSZ=15 \
-        --simlimit 120
+        --simlimit ${SIM_LIMIT_SEC}

+      # Disable exit on error
+      set +e
       # Report failure if any of the ouptut verification checks fail
       grep -qF "ERROR" sim.out
       exit_status=$? #store 0 if found (failure), 1 if not (success)
       if [[ "$exit_status" -eq "0" ]]; then
-        cat sim.out
-        exit 1
+        cat sim.out
+        set -e
+        exit 1
       fi

       # Report fail if simulation does not complete successfully
       grep -qF "SUCCESS: Program complete, exiting." sim.out
       exit_status=$? #store 0 if found (success), 1 if not (failure)
       if [[ "$exit_status" -eq "1" ]]; then
-        cat sim.out
-        exit 1
+        cat sim.out
+        set -e
+        exit 1
       fi
+      # Re-enable exit on error
+      set -e

   test-models-linux-aarch64:
     name: test-models-linux-aarch64

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 36 additions & 15 deletions
@@ -12,8 +12,11 @@
 )
 from executorch.backends.qualcomm.quantizer.quantizer import (
     get_16a8w_qnn_ptq_config,
+    get_16a8w_qnn_qat_config,
     get_8a8w_qnn_ptq_config,
+    get_8a8w_qnn_qat_config,
     get_ptq_per_channel_quant_config,
+    get_qat_per_channel_quant_config,
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -154,7 +157,9 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):


 def annotate_matmul_16a8w(  # noqa: C901
-    gm: torch.fx.GraphModule, annotate_conv=True
+    gm: torch.fx.GraphModule,
+    annotate_conv=True,
+    is_qat=False,
 ) -> None:
     """
     This function is specific for matmul op 16a8w.
@@ -242,7 +247,6 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
     def annotate_single_in_single_out(
         node: Node, quantization_config: QuantizationConfig
     ) -> None:
-
         input_qspec_map = {}
         input_act = node.args[0]
         input_qspec_map[input_act] = quantization_config.input_activation
@@ -256,7 +260,6 @@ def annotate_single_in_single_out(
     def annotate_single_in_share_out(
         node: Node, quantization_config: QuantizationConfig
     ) -> None:
-
         input_qspec_map = {}
         input_act = node.args[0]
         input_qspec_map[input_act] = quantization_config.input_activation
@@ -287,16 +290,27 @@ def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
             _annotated=True,
         )

-    def annotate_matmul_input1(node: Node):
-        quantization_config_8a8w = get_8a8w_qnn_ptq_config(
-            act_symmetric=True, act_observer=MinMaxObserver
-        )
-        quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
-            act_dtype=torch.uint8,
-            weight_dtype=torch.int4,
-            act_observer=MinMaxObserver,
-            act_symmetric=True,
-        )
+    def annotate_matmul_input1(node: Node, is_qat: str):
+        if is_qat:
+            quantization_config_8a8w = get_8a8w_qnn_qat_config(
+                act_symmetric=True, act_observer=MinMaxObserver
+            )
+            quantization_config_8a4w_per_channel = get_qat_per_channel_quant_config(
+                act_dtype=torch.uint8,
+                weight_dtype=torch.int4,
+                act_observer=MinMaxObserver,
+                act_symmetric=True,
+            )
+        else:
+            quantization_config_8a8w = get_8a8w_qnn_ptq_config(
+                act_symmetric=True, act_observer=MinMaxObserver
+            )
+            quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
+                act_dtype=torch.uint8,
+                weight_dtype=torch.int4,
+                act_observer=MinMaxObserver,
+                act_symmetric=True,
+            )
         while isinstance(node, Node) and node.op == "call_function":
             if node.target in [
                 torch.ops.aten.permute.default,
@@ -334,12 +348,19 @@ def annotate_matmul_input1(node: Node):
                 print(f"The node ({node}) is not expected in the input1 of the matmul")
             node = node.args[0]

-    quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)
+    if is_qat:
+        quantization_config_16a8w = get_16a8w_qnn_qat_config(
+            act_observer=MinMaxObserver
+        )
+    else:
+        quantization_config_16a8w = get_16a8w_qnn_ptq_config(
+            act_observer=MinMaxObserver
+        )

     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
-            annotate_matmul_input1(node.args[1])
+            annotate_matmul_input1(node.args[1], is_qat=is_qat)


 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
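Note: the sketch below is illustrative only. The toy module and the export step are assumptions; only the annotate_matmul_16a8w signature (gm, annotate_conv, is_qat) comes from this diff, and in a real flow the graph module would normally come from the Qualcomm quantization pipeline rather than a bare torch.export.

# Sketch only: TinyMatmul and the export call are hypothetical; the keyword
# arguments mirror the signature introduced in this commit.
import torch

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
)


class TinyMatmul(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y)


example_inputs = (torch.randn(1, 8, 16), torch.randn(1, 16, 8))
gm = torch.export.export(TinyMatmul(), example_inputs).module()

# is_qat=False (the default) keeps the previous PTQ annotation path;
# is_qat=True routes matmul and its input1 chain through the QAT configs.
annotate_matmul_16a8w(gm, annotate_conv=True, is_qat=True)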

backends/qualcomm/quantizer/qconfig.py

Lines changed: 107 additions & 16 deletions
@@ -187,6 +187,65 @@ def get_16a8w_qnn_ptq_config(
     return quantization_config


+def get_16a8w_qnn_qat_config(
+    act_observer=MovingAverageMinMaxObserver,
+) -> QuantizationConfig:
+    extra_args: Dict[str, Any] = {"eps": 2**-20}
+    act_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.uint16).min,
+        quant_max=torch.iinfo(torch.uint16).max,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=True,
+        observer=act_observer.with_args(**extra_args),
+    )
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.uint16).min,
+        quant_max=torch.iinfo(torch.uint16).max,
+        qscheme=torch.per_tensor_affine,
+        observer_or_fake_quant_ctr=act_fake_quant_ctr,
+    )
+    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(torch.int8).min + 1,
+        quant_max=torch.iinfo(torch.int8).max,
+        qscheme=torch.per_tensor_symmetric,
+        reduce_range=True,
+        observer=MovingAverageMinMaxObserver,
+    )
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(torch.int8).min + 1,
+        quant_max=torch.iinfo(torch.int8).max,
+        qscheme=torch.per_tensor_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
+    )
+    bias_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer=MovingAverageMinMaxObserver,
+    )
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
+    )
+    quantization_config = QuantizationConfig(
+        input_activation=act_quantization_spec,
+        output_activation=act_quantization_spec,
+        weight=weight_quantization_spec,
+        bias=bias_quantization_spec,
+    )
+
+    return quantization_config
+
+
 def get_16a16w_qnn_ptq_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
@@ -459,6 +518,7 @@ def get_qat_per_channel_quant_config(
     act_dtype=torch.uint8,
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
+    act_symmetric=False,
 ) -> QuantizationConfig:
     supported_act_types = {
         torch.uint8,
@@ -476,21 +536,38 @@ def get_qat_per_channel_quant_config(
     ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

     # torch does not support uint16 quantization, use int32 to bypass
-    act_fake_quant_ctr = FakeQuantize.with_args(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        reduce_range=True,
-        observer=act_observer,
-    )
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=act_fake_quant_ctr,
-    )
+    if act_symmetric:
+        # If zero_point is 128, htp can do optimizations.
+        # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
+        # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
+        act_fake_quant_ctr = FakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            reduce_range=True,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
+    else:
+        act_fake_quant_ctr = FakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            reduce_range=True,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )

     weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
@@ -513,7 +590,21 @@ def get_qat_per_channel_quant_config(
         observer_or_fake_quant_ctr=weight_fake_quant_ctr,
     )

-    bias_quantization_spec = _derived_bias_quant_spec
+    bias_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        reduce_range=True,
+        observer=MovingAverageMinMaxObserver,
+    )
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
+    )

     quantization_config = QuantizationConfig(
         input_activation=act_quantization_spec,
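Note: a minimal sketch of invoking the helpers this file adds or extends. The import path (via the quantizer module's re-exports) and the MinMaxObserver / 8a4w pairing mirror the custom_annotation.py call sites in this commit; the surrounding script itself is illustrative, not part of the change.

# Illustrative only: names come from this commit's diffs; nothing else is implied.
import torch
from torch.ao.quantization.observer import MinMaxObserver

from executorch.backends.qualcomm.quantizer.quantizer import (
    get_16a8w_qnn_qat_config,
    get_qat_per_channel_quant_config,
)

# New 16-bit-activation / 8-bit-weight QAT config. uint16 activations are
# modeled with torch.int32 specs bounded by torch.iinfo(torch.uint16), since
# torch's fake quantization has no native uint16 dtype.
qat_16a8w = get_16a8w_qnn_qat_config(act_observer=MinMaxObserver)

# get_qat_per_channel_quant_config now accepts act_symmetric. With
# act_symmetric=True, quant_min/quant_max are left unset so the observer
# defaults to a zero_point of 128, which HTP can optimize (per the comment
# added in this diff).
qat_8a4w_per_channel = get_qat_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int4,
    act_observer=MinMaxObserver,
    act_symmetric=True,
)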

backends/qualcomm/quantizer/quantizer.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     get_16a4w_qnn_ptq_config,
     get_16a4w_qnn_qat_config,
     get_16a8w_qnn_ptq_config,
+    get_16a8w_qnn_qat_config,
     get_8a8w_qnn_ptq_config,
     get_8a8w_qnn_qat_config,
     get_ptq_per_block_quant_config,
@@ -39,6 +40,7 @@
     "QuantDtype",
     "get_16a4w_qnn_ptq_config",
     "get_16a8w_qnn_ptq_config",
+    "get_16a8w_qnn_qat_config",
     "get_16a16w_qnn_ptq_config",
     "get_8a8w_qnn_ptq_config",
     "get_8a8w_qnn_qat_config",

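Note: a minimal import check for the new re-export, assuming an environment with the ExecuTorch Qualcomm backend installed; it only uses names that appear in this file's import list and __all__.

# The QAT variant now sits next to the existing PTQ helper in the public API.
from executorch.backends.qualcomm.quantizer.quantizer import (
    get_16a8w_qnn_ptq_config,
    get_16a8w_qnn_qat_config,
)

ptq_cfg = get_16a8w_qnn_ptq_config()  # existing PTQ config
qat_cfg = get_16a8w_qnn_qat_config()  # added in this commit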