
Commit 0e9d871

zingo, oscarandersson8218, and digantdesai authored

Arm backend: Update weight observer for QAT (#14115)

Co-authored-by: Oscar Andersson <[email protected]>
Co-authored-by: Digant Desai <[email protected]>
1 parent 72d50b2 commit 0e9d871

2 files changed: +115 -5 lines changed


backends/arm/quantizer/arm_quantizer.py

Lines changed: 15 additions & 5 deletions
@@ -105,15 +105,27 @@ def get_symmetric_quantization_config(
     # Determine the right observer/fake-quant constructor
     if is_qat:
         if is_per_channel:
-            weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+            weight_observer_or_fake_quant_ctr = FakeQuantize.with_args(
+                observer=PerChannelMinMaxObserver,
+                quant_min=weight_qmin,
+                quant_max=weight_qmax,
+                dtype=torch.qint8,
+                qscheme=torch.per_channel_symmetric,
+                reduce_range=False,
+                ch_axis=0,
+                **extra_args,
+            )
         else:
             # Set plain fake-quant with true min/max
-            weight_observer_or_fake_quant_ctr = FakeQuantize
+            weight_observer_or_fake_quant_ctr = FakeQuantize.with_args(**extra_args)
     else:
         # PTQ: set min/max observer
         weight_observer_or_fake_quant_ctr = (
             PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
         )
+        weight_observer_or_fake_quant_ctr = weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        )
 
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,

@@ -122,9 +134,7 @@ def get_symmetric_quantization_config(
         qscheme=weight_qscheme,
         ch_axis=0,
         is_dynamic=False,
-        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
-            **extra_args
-        ),
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
     )
 
     bias_quantization_spec = None
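
For context, a minimal usage sketch (not part of this commit) of how the updated config would typically be requested and attached to the quantizer. It assumes get_symmetric_quantization_config exposes is_per_channel and is_qat as keyword arguments, matching the flags used in the diff above, and reuses the TOSAQuantizer setup from the new test below:

# Hypothetical usage sketch: request the QAT per-channel config and attach
# it to the TOSA quantizer. Assumes is_per_channel/is_qat are keyword
# arguments of get_symmetric_quantization_config, as the diff's flags suggest.
from executorch.backends.arm.quantizer import (
    get_symmetric_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa.specification import TosaSpecification

quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
# With is_qat=True and is_per_channel=True, the weight observer is now a
# FakeQuantize constructed over PerChannelMinMaxObserver (per the change above).
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
)
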
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from executorch.backends.arm.quantizer import (
    get_symmetric_quantization_config,
    TOSAQuantizer,
)

from executorch.backends.arm.tosa.specification import TosaSpecification
from torch.export import export
from torchao.quantization.pt2e import (
    move_exported_model_to_eval,
    move_exported_model_to_train,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

logger = logging.getLogger(__name__)


class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sequential = torch.nn.Sequential(
            torch.nn.Linear(1, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )

    def forward(self, x):
        return self.sequential(x)


def evaluate_model(model, inputs, expected_outputs):
    with torch.no_grad():
        test_outputs = model(inputs)
        loss = torch.nn.functional.mse_loss(test_outputs, expected_outputs)
        logger.info(f"Mean squared error: {loss.item()}")


def test_qat_training_loop():
    """Test the QAT training loop with a simple MLP model.
    This function creates a simple MLP model, prepares it for QAT, runs a training loop,
    and evaluates the quantized model to make sure everything works as expected."""

    model = MLP()
    logger.info("Starting training loop test")
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for epoch in range(100):
        model.train()
        optimizer.zero_grad()
        inputs = torch.randn(100, 1).clamp(-1, 1)
        outputs = model(inputs)
        loss = torch.nn.functional.mse_loss(outputs, torch.sin(inputs))
        loss.backward()
        optimizer.step()
        if epoch % 5 == 0:
            logger.info(f"Epoch {epoch}, Loss: {loss.item()}")
    logger.info("Training loop test completed successfully")

    logger.info("Evaluating model before QAT")
    test_inputs = torch.randn(20, 1).clamp(-1, 1)
    test_outputs = torch.sin(test_inputs)
    evaluate_model(model, test_inputs, test_outputs)

    exported_model = export(model, (torch.randn(1, 1),), strict=True)
    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
    quantizer.set_global(get_symmetric_quantization_config(is_qat=True))

    prepared_model = prepare_qat_pt2e(exported_model.module(), quantizer)
    prepared_model = move_exported_model_to_train(prepared_model)
    logger.info("QAT model prepared successfully")

    logger.info("Starting QAT training loop")

    for epoch in range(25):
        inputs = torch.randn(100, 1).clamp(-1, 1)
        optimizer.zero_grad()
        outputs = prepared_model(inputs)
        loss = torch.nn.functional.mse_loss(outputs, torch.sin(inputs))
        loss.backward()
        optimizer.step()
        if epoch % 5 == 0:
            logger.info(f"QAT Epoch {epoch}, Loss: {loss.item()}")
    logger.info("QAT training loop completed successfully")
    prepared_model = move_exported_model_to_eval(prepared_model)

    quantized_model = convert_pt2e(prepared_model)
    logger.info("QAT model quantized successfully")

    logger.info("Evaluating quantized model")
    test_inputs = torch.randn(100, 1).clamp(-1, 1)
    test_outputs = torch.sin(test_inputs)
    evaluate_model(quantized_model, test_inputs, test_outputs)
