Commit 58a0d08

Merge branch 'main' into fix-fp-to-int-casting-lowering
2 parents: 543740c + d952326

6 files changed (+237, -125 lines)

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 23 additions & 0 deletions
@@ -20,6 +20,8 @@
 from torch.fx import Node
 from torchao.quantization.pt2e import FixedQParamsObserver, MinMaxObserver
 from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
@@ -213,6 +215,24 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
         _annotated=True,
     )
 
+def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
+    act_node = node.args[0]
+    weight_node = node.args[2]
+
+    # TODO current only support 16a16w
+    annotate_input_qspec_map(
+        node,
+        act_node,
+        quantization_config.input_activation,
+    )
+
+    annotate_input_qspec_map(
+        node,
+        weight_node,
+        quantization_config.input_activation,
+    )
+    annotate_output_qspec(node, quantization_config.output_activation)
+
 def annotate_single_in_single_out(
     node: Node, quantization_config: QuantizationConfig
 ) -> None:
@@ -287,6 +307,9 @@ def annotate_matmul_input1(node: Node):
        elif node.target == torch.ops.aten.flatten.using_ints:
            annotate_single_in_share_out(node, quantization_config_8a8w)
            node = node.args[0]
+       elif node.target == torch.ops.aten.rms_norm.default:
+           annotate_rms_norm(node, quantization_config_8a8w)
+           node = node.args[0]
        elif node.target == torch.ops.aten.cat.default:
            annotate_cat(node, quantization_config_8a8w)
            # For v, we tag 8a until conv op.
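
For context: aten.rms_norm takes (input, normalized_shape, weight, eps), which is why the new helper reads the activation from node.args[0] and the affine weight from node.args[2]. A minimal illustration of that signature, assuming a recent PyTorch that provides torch.nn.functional.rms_norm (not code from this commit):

import torch

x = torch.randn(2, 8)   # activation -> args[0] of aten.rms_norm
gamma = torch.ones(8)   # affine weight -> args[2]
y = torch.nn.functional.rms_norm(x, normalized_shape=(8,), weight=gamma, eps=1e-6)
print(y.shape)  # torch.Size([2, 8])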

examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 58 additions & 9 deletions
@@ -20,6 +20,14 @@
     annotate_matmul_16a8w,
 )
 
+from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
+    PerChannelParamObserver,
+)
+from executorch.backends.qualcomm.quantizer.qconfig import (
+    _derived_bias_quant_spec,
+    QuantizationConfig,
+)
+
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
 
@@ -47,6 +55,8 @@
 
 from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
+
 
 sys.setrecursionlimit(4096)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -78,6 +88,33 @@ def forward(
         return self.model.forward(tokens, self.atten_mask)
 
 
+def add_mse_weight_observer(quant_dtype, quantizer):
+    weight_dtype = (
+        torch.int4
+        if quant_dtype in (QuantDtype.use_16a4w, QuantDtype.use_16a4w_block)
+        else torch.int8
+    )
+    per_channel_q_config = quantizer.default_quant_config.quant_config
+    weight_qspec = QuantizationSpec(
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=(7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max),
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
+            **{"steps": 200, "use_mse": True}
+        ),
+    )
+    quantizer.default_quant_config.per_channel_quant_config = QuantizationConfig(
+        input_activation=per_channel_q_config.input_activation,
+        output_activation=per_channel_q_config.output_activation,
+        weight=weight_qspec,
+        bias=_derived_bias_quant_spec,
+    )
+
+
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
     with open(args.params) as f:
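
For context: with --range_setting mse_weight, the default per-channel weight observer is swapped for PerChannelParamObserver(steps=200, use_mse=True), and 4-bit weights use a symmetric ±7 range stored in int8 (see the QuantizationSpec above). The idea behind MSE-based range setting, as a rough self-contained sketch of the concept rather than the observer's actual implementation:

import torch

def mse_weight_scale_per_channel(w: torch.Tensor, n_bits: int = 4, steps: int = 200) -> torch.Tensor:
    # Sweep progressively smaller clip ranges and keep, per output channel,
    # the one whose fake-quantized reconstruction has the lowest MSE.
    flat = w.reshape(w.shape[0], -1)
    abs_max = flat.abs().amax(dim=1)            # the plain min/max range
    qmax = 2 ** (n_bits - 1) - 1                # 7 for 4-bit symmetric
    best_max = abs_max.clone()
    best_err = torch.full_like(abs_max, float("inf"))
    for i in range(1, steps + 1):
        cand = abs_max * i / steps              # candidate clip value
        scale = (cand / qmax).clamp(min=1e-12)
        q = (flat / scale[:, None]).round().clamp(-qmax, qmax)
        err = ((q * scale[:, None] - flat) ** 2).mean(dim=1)
        better = err < best_err
        best_err = torch.where(better, err, best_err)
        best_max = torch.where(better, cand, best_max)
    return best_max / qmax                      # per-channel scale

print(mse_weight_scale_per_channel(torch.randn(4, 16)).shape)  # torch.Size([4])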
@@ -142,13 +179,13 @@ def permute(w, heads):
        if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
            layer.feed_forward.prepare_feedfoward_conv()
 
-    model.to(dtype=torch.bfloat16)
+    model.to(dtype=torch.float)
     model.to(device=args.device)
 
     tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
     tokens = tokens.to(device=args.device)
     atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    atten_mask = atten_mask.to(dtype=torch.float)
     inputs = (tokens, atten_mask)
 
     if args.embedding_quantize:
@@ -174,7 +211,8 @@ def permute(w, heads):
        )
        quantizer.add_custom_quant_annotations(custom_annotations)
 
-       model.has_quant_io = True
+       if args.range_setting == "mse_weight":
+           add_mse_weight_observer(quant_dtype, quantizer)
 
        with torch.no_grad():
            model = torch.export.export(model, inputs, strict=True).module()
@@ -245,6 +283,23 @@ def main() -> None:
     torch.manual_seed(seed)
     modelname = "llama2"
     parser = build_args_parser()
+    parser.add_argument(
+        "-P",
+        "--ptq",
+        help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w, 16a4w and 16a4w_block.",
+        type=str,
+    )
+    parser.add_argument(
+        "--range_setting",
+        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        type=str,
+    )
+    parser.add_argument(
+        "--limit",
+        help="the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples",
+        type=str,
+    )
+
     args = parser.parse_args()
     args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
@@ -257,15 +312,9 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
-    # To do fewer samples for faster evaluation
-    args.limit = 0.1
-    # args.samples = {'wikitext': list(range(1))}
-
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_default_device(args.device)
 
-    args.ptq = "8a8w"
-
    eval_llama(modelname, args)
 
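For context: the three new flags replace values that main() previously hard-coded (args.ptq = "8a8w", args.limit = 0.1). The --limit help text implies the usual fraction-vs-count convention; a hedged sketch of how such a string might be normalized downstream (the commit itself leaves this to the eval harness):

from typing import Optional, Union

def normalize_limit(limit: Optional[str]) -> Optional[Union[int, float]]:
    # Values below 1 act as a fraction of the dataset, otherwise as an absolute count.
    if limit is None:
        return None
    value = float(limit)
    return value if value < 1 else int(value)

print(normalize_limit("0.1"), normalize_limit("200"), normalize_limit(None))  # 0.1 200 None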

exir/passes/remove_mixed_type_operators.py

Lines changed: 9 additions & 1 deletion
@@ -23,12 +23,20 @@ def call_operator(self, op, args, kwargs, meta: NodeMetadata):  # noqa: C901
         promotion_type_allow_list = {
             torch.ops.aten.add.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
             torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
-            torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            # The correct promotion for div depends on the mode! If there is no mode,
+            # it's INT_TO_FLOAT, otherwise it's default.
+            torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            torch.ops.aten.div.Tensor_mode: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
             torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
         }
 
         if op in promotion_type_allow_list:
             promotion_kind = promotion_type_allow_list[op]
+            if (
+                op == torch.ops.aten.div.Tensor_mode
+                and kwargs.get("rounding_mode") is None
+            ):
+                promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
         else:
             # Not in allow list, do nothing
             return super().call_operator(op, args, kwargs, meta)
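
For context: eager PyTorch promotes integer inputs to floating point for true division but keeps the integer dtype when a rounding mode is given, which is exactly the asymmetry the comment above describes. A quick illustration of the standard torch semantics (shown for context, not taken from this diff):

import torch

a = torch.tensor([7, 8, 9])  # int64
b = torch.tensor([2, 2, 2])  # int64

print(torch.div(a, b).dtype)                         # torch.float32: true division promotes
print(torch.div(a, b, rounding_mode=None).dtype)     # torch.float32: None behaves like no mode
print(torch.div(a, b, rounding_mode="trunc").dtype)  # torch.int64: integer dtype is kept
print(torch.div(a, b, rounding_mode="floor").dtype)  # torch.int64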

exir/tests/test_passes.py

Lines changed: 101 additions & 77 deletions
@@ -9,7 +9,7 @@
 import os
 import tempfile
 import unittest
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import executorch.exir as exir
 
@@ -71,6 +71,7 @@
 from functorch.experimental import control_flow
 
 from torch import nn
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.export import export
 from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
 from torch.fx import GraphModule, subgraph_rewriter
@@ -121,91 +122,114 @@ def foo_out(
     return a + 1, None
 
 
+def simple_promote_dtype(
+    dtype: torch.dtype, promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
+) -> torch.dtype:
+    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
+        return dtype
+    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
+        return dtype if dtype.is_floating_point else torch.float
+    else:
+        raise Exception(f"Unsupported promotion kind {promotion_kind}")
+
+
+def count_nodes_with_target_asserting_arguments_have_dtype(
+    self, module, target, arg_dtype
+) -> int:
+    count = 0
+    for node in module.graph.nodes:
+        if node.op == "call_function" and node.target == target:
+            count += 1
+            for arg in node.args:
+                self.assertEqual(arg.meta["val"].dtype, arg_dtype)
+    return count
+
+
 class TestPasses(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         register_additional_test_aten_ops()
 
     def test_remove_mixed_type_operators(self) -> None:
-        class Add(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return (x + y) + x
-
-        add = Add()
-
-        int_tensor = torch.tensor([[1, 2, 3]])
-        float_tensor = torch.tensor([[1.0, 2.0, 3.0]])
-        edge_prog = to_edge(export(add, (int_tensor, float_tensor), strict=True))
-
-        new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
-        new_graph_module = new_prog.exported_program().graph_module
-        self.assertIsNotNone(new_graph_module)
-
-        add_count = 0
-
-        for node in new_graph_module.graph.nodes:
-            if (
-                node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ):
-                add_count += 1
-                node_args = node.args
-                for arg in node_args:
-                    self.assertEqual(arg.meta["val"].dtype, torch.float)
-
-        self.assertEqual(add_count, 2)
-
-        double_tensor = torch.tensor([[1.0, 2.0, 3.0]])
-        double_tensor = double_tensor.to(torch.double)
-
-        double_prog = to_edge(export(add, (int_tensor, double_tensor), strict=True))
-
-        double_prog.transform([RemoveMixedTypeOperators()])
-        new_graph_module_double = double_prog.exported_program().graph_module
-        self.assertIsNotNone(new_graph_module_double)
-
-        add_count_double = 0
-
-        for node in new_graph_module_double.graph.nodes:
-            if (
-                node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ):
-                add_count_double += 1
-                node_args = node.args
-                for arg in node_args:
-                    self.assertEqual(arg.meta["val"].dtype, torch.double)
-
-        self.assertEqual(add_count_double, 2)
-
-        class Mult(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x * y
-
-        mult = Mult()
-
-        float_tensor_vert = float_tensor.T
-        mult_prog = to_edge(export(mult, (int_tensor, float_tensor_vert), strict=True))
-
-        # graph_module_mult.graph.print_tabular()
+        def make_module(fwd: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
+            class Module(torch.nn.Module):
+                def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                    return fwd(x, y)
+
+            return Module
+
+        Add = make_module(lambda x, y: (x + y) + x)
+        Mult = make_module(lambda x, y: x * y)
+        Minimum = make_module(torch.minimum)
+        DivWithoutMode = make_module(torch.div)
+        DivWithNoneMode = make_module(lambda x, y: torch.div(x, y, rounding_mode=None))
+        DivWithTruncMode = make_module(
+            lambda x, y: torch.div(x, y, rounding_mode="trunc")
+        )
+        DivWithFloorMode = make_module(
+            lambda x, y: torch.div(x, y, rounding_mode="floor")
+        )
 
-        mult_prog = mult_prog.transform([RemoveMixedTypeOperators()])
-        new_graph_module_mult = mult_prog.exported_program().graph_module
-        self.assertIsNotNone(new_graph_module_mult)
+        for module, op, expected_count, promotion_kind in (
+            (
+                Add,
+                exir_ops.edge.aten.add.Tensor,
+                2,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                Mult,
+                exir_ops.edge.aten.mul.Tensor,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                Minimum,
+                exir_ops.edge.aten.minimum.default,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                DivWithoutMode,
+                exir_ops.edge.aten.div.Tensor,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            ),
+            (
+                DivWithNoneMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            ),
+            (
+                DivWithTruncMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                DivWithFloorMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+        ):
+            for second_arg_dtype in (torch.int64, torch.float, torch.double):
+                int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64)
+                float_tensor = torch.tensor([[1.0, 2.0, 3.0]], dtype=second_arg_dtype)
+                edge_prog = to_edge(
+                    export(module(), (int_tensor, float_tensor), strict=True)
+                )
 
-        mult_count = 0
+                new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
+                new_graph_module = new_prog.exported_program().graph_module
+                self.assertIsNotNone(new_graph_module)
 
-        for node in new_graph_module_mult.graph.nodes:
-            if (
-                node.op == "call_function"
-                and node.target == exir_ops.edge.aten.mul.Tensor
-            ):
-                mult_count += 1
-                node_args = node.args
-                for arg in node_args:
-                    self.assertEqual(arg.meta["val"].dtype, torch.float)
-
-        self.assertEqual(mult_count, 1)
+                promoted_type = simple_promote_dtype(second_arg_dtype, promotion_kind)
+                count = count_nodes_with_target_asserting_arguments_have_dtype(
+                    self, new_graph_module, op, promoted_type
+                )
+                self.assertEqual(count, expected_count)
 
     def test_remove_noop_pass(self) -> None:
         class Foo(torch.nn.Module):
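
For context, the expectations the parametrized test encodes: with an int64 first operand, DEFAULT promotion keeps the second operand's dtype, while INT_TO_FLOAT floors integer inputs at torch.float. A small illustration reusing simple_promote_dtype from the hunk above (illustrative, not part of the commit):

import torch
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND

# simple_promote_dtype is the helper added in exir/tests/test_passes.py above.
assert simple_promote_dtype(torch.int64, ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) == torch.int64
assert simple_promote_dtype(torch.double, ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) == torch.double
assert simple_promote_dtype(torch.int64, ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) == torch.float
assert simple_promote_dtype(torch.double, ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) == torch.double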
