
Commit 10689f5

Marco Giordano authored and facebook-github-bot committed
Adding mixed quantization support (#14134)
Summary:

# Context

This diff adds support for mixed quantization operators in ExecuTorch: weights and biases can be quantized while inputs and activations are kept in floating point.

# In this diff

1. The matched op node is now returned from each pattern's `get_anchors`.
2. Dequantize nodes are bypassed when they are not needed in the final graph.

Reviewed By: skrtskrtfb

Differential Revision: D81519735
1 parent dc87d22 commit 10689f5
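To make the mechanics concrete, here is a minimal sketch, not the ExecuTorch source: it assumes the `quantized_decomposed` ops are registered (as they are in the quantizer flow), and the helper names are illustrative. It shows how the op node returned by a pattern lets the fusion pass tell a fully quantized output (a quantize node follows the op) from a mixed, float-activation one.

```python
# A minimal sketch, not the ExecuTorch implementation. Helper names are
# illustrative; it assumes torch.ops.quantized_decomposed is registered.
from typing import Optional

import torch
from torch import fx


def trailing_quantize_node(op_node: fx.Node) -> Optional[fx.Node]:
    """Return the quantize_per_tensor node consuming op_node, if it is the sole user."""
    if len(op_node.users) != 1:
        return None
    user = next(iter(op_node.users))
    if user.target == torch.ops.quantized_decomposed.quantize_per_tensor.default:
        return user
    return None


def propagate_meta(fused: fx.Node, op_node: fx.Node) -> None:
    """Copy metadata from the trailing quantize node when present, else from the op."""
    quant_node = trailing_quantize_node(op_node)
    fused.meta = quant_node.meta if quant_node is not None else op_node.meta
```

When no trailing quantize node is found, the fused op simply takes the matched op's place and keeps its metadata, which is what the fusion-pass diff below implements.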

File tree

3 files changed: +62 −50 lines


backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 25 additions & 13 deletions
@@ -402,7 +402,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                 pattern.partition_types(),
             )
             for fused_partition in fused_partitions:
-                anchors = pattern.get_anchors(graph_module, fused_partition)
+                anchors, op_node = pattern.get_anchors(graph_module, fused_partition)
                 if not anchors or anchors.empty:
                     continue
                 if any(self.is_fused(p.nodes) for p in fused_partition):
@@ -443,13 +443,17 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                 bias_inputs = [node.args[0] for node in dequants_biases]
                 other_inputs = [node.args[idx] for node, idx in anchors.others]

-                # The node is the first index of the list and first of the tuple
-                anchor_output_node = anchors.output[0][0]
-
-                assert len(anchor_output_node.users) == 1
-                quant_node = list(anchor_output_node.users.keys())[0]
-
-                with graph_module.graph.inserting_after(anchor_output_node):
+                # Check if there's a quantization node after the operation
+                quant_node = None
+                if len(anchors.output) == 1:
+                    # Check if it's actually a quantization node
+                    if hasattr(op_node, 'users') and len(op_node.users) == 1:
+                        potential_quant_node = list(op_node.users.keys())[0]
+                        if (potential_quant_node.target ==
+                                torch.ops.quantized_decomposed.quantize_per_tensor.default):
+                            quant_node = potential_quant_node
+
+                with graph_module.graph.inserting_after(op_node):
                     args = tuple(
                         inputs_inputs + weights_inputs + other_inputs + bias_inputs
                     )
@@ -463,7 +467,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                         )
                     elif isinstance(pattern, CatPattern):
                         args, kwargs = get_args_and_kwargs_cat(
-                            inputs_inputs, other_inputs, anchor_output_node
+                            inputs_inputs, other_inputs, op_node
                         )
                     elif isinstance(pattern, ConvReluPatterns):
                         # For ConvReLU, we are fusing Conv+ReLU
@@ -494,7 +498,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                             dequants_weights,
                             bias_inputs,
                             quant_node,
-                            anchor_output_node,
+                            op_node,
                         )
                     elif isinstance(pattern, LinearPattern):
                         args, kwargs = get_args_and_kwargs_linear(
@@ -543,18 +547,26 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                             dequants_inputs,
                             quant_node,
                         )
+
                     fused = graph_module.graph.call_function(
                         pattern.replacement_op(),
                         args,
                         kwargs,
                     )
-                    fused.meta = quant_node.meta
-                    quant_node.replace_all_uses_with(fused)
+
+                    if quant_node:
+                        fused.meta = quant_node.meta
+                    else:
+                        fused.meta = op_node.meta
+                    op_node.replace_all_uses_with(fused)
+                    if op_node.op == "output":
+                        _ = graph_module.graph.output((fused,))

         legalize_graph(graph_module)
         graph_module.graph.eliminate_dead_code()
-        # pyre-fixme[7]: Incompatible return type
         graph_module.recompile()
+        return PassResult(graph_module, True)
+

     @classmethod
     # pyre-ignore[2]: Parameter `nodes` has no type specified

backends/cadence/aot/quantizer/patterns.py

Lines changed: 36 additions & 36 deletions
@@ -67,7 +67,7 @@ def partition_types(self) -> list[OpOverload]:
     @abstractmethod
     def get_anchors(
         self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> Optional[PartitionAnchors]:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         pass

     @abstractmethod
@@ -85,7 +85,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         addmm_node = fused_partition[0].nodes[-1]

@@ -101,12 +101,12 @@ def get_anchors(
             qscheme=torch.per_tensor_affine,
         )

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(addmm_node, 1)],
             weights=[(addmm_node, 2)],
             biases=[(addmm_node, 0, bias_qspec)],
             output=[(addmm_node,)],
-        )
+        ), addmm_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default
@@ -118,7 +118,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         add_node = fused_partition[0].nodes[-1]

@@ -129,16 +129,16 @@ def get_anchors(
             add_node.args[1], fx.Node
         )
         if not is_tensor_add or len(add_node.kwargs) > 0:
-            return PartitionAnchors(
+            return (PartitionAnchors(
                 empty=True,
-            )
+            ), add_node)

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(add_node, 0), (add_node, 1)],
             weights=[],
             biases=[],
             output=[(add_node,)],
-        )
+        ), add_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_add.default
@@ -150,16 +150,16 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         bmm_node = fused_partition[0].nodes[-1]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(bmm_node, 0), (bmm_node, 1)],
             weights=[],
             biases=[],
             output=[(bmm_node,)],
-        )
+        ), bmm_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
@@ -171,7 +171,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         cat_node = fused_partition[0].nodes[-1]

@@ -198,14 +198,14 @@ def get_anchors(
                 )
             )

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=args,
             weights=[],
             biases=[],
             output=[
                 (cat_node, SharedQuantizationSpec((cat_node.args[0][0], cat_node)))
             ],
-        )
+        ), cat_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.aten.cat.default
@@ -217,7 +217,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         conv1d_node = fused_partition[0].nodes[-1]

@@ -238,13 +238,13 @@ def get_anchors(
         if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
             bias = [(conv1d_node, 2, bias_qspec)]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(conv1d_node, 0)],
             weights=[(conv1d_node, 1)],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(conv1d_node,)],
-        )
+        ), conv1d_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv2d_nchw.default
@@ -256,7 +256,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         conv2d_node = fused_partition[0].nodes[-1]

@@ -277,13 +277,13 @@ def get_anchors(
         if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
             bias = [(conv2d_node, 2, bias_qspec)]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(conv2d_node, 0)],
             weights=[(conv2d_node, 1)],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(conv2d_node,)],
-        )
+        ), conv2d_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv2d_nchw.default
@@ -295,7 +295,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         layer_norm_node = fused_partition[0].nodes[-1]

@@ -311,14 +311,14 @@ def get_anchors(

         # Weights are used in quantized mode by our kernel, so they are
         # passed in as others here along with the normalized shape.
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(layer_norm_node, 0)],
             weights=[],
             biases=[],
             # Ordering: normalized_shape, weights, bias
             others=others,
             output=[(layer_norm_node,)],
-        )
+        ), layer_norm_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_layer_norm.default
@@ -330,7 +330,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         linear_node = fused_partition[0].nodes[-1]

@@ -351,13 +351,13 @@ def get_anchors(
         if len(linear_node.args) > 2:
             bias = [(linear_node, 2, bias_qspec)]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(linear_node, 0)],
             weights=[(linear_node, 1)],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(linear_node,)],
-        )
+        ), linear_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default
@@ -369,16 +369,16 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         matmul_node = fused_partition[0].nodes[-1]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(matmul_node, 0), (matmul_node, 1)],
             weights=[],
             biases=[],
             output=[(matmul_node,)],
-        )
+        ), matmul_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
@@ -392,16 +392,16 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         relu_node = fused_partition[0].nodes[-1]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(relu_node, 0)],
             weights=[],
             biases=[],
             output=[(relu_node,)],
-        )
+        ), relu_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_relu.default
@@ -427,7 +427,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # The first node should be conv, the second should be relu
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         conv_node = fused_partition[0].nodes[-1]  # Second to last node
@@ -451,13 +451,13 @@ def get_anchors(
         if len(conv_node.args) > 2 and conv_node.args[2] is not None:
             bias = [(conv_node, 2, bias_qspec)]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(conv_node, 0)],
             weights=[(conv_node, 1)],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(relu_node,)],  # Output is from the relu node
-        )
+        ), relu_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv2d_nchw.default
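For anyone writing a new pattern against the updated contract, the shape of the return value is the only change: `get_anchors` now pairs the `PartitionAnchors` with the matched op node. Below is a hypothetical sketch under that contract; the pattern is not part of this commit, and the base class name and import path are assumed from the file above.

```python
# Hypothetical example only: a pattern returning (PartitionAnchors, op_node).
# Import path and QuantizationPattern base class are assumed from patterns.py.
from typing import List, Tuple

import torch
from torch import fx
from torch._ops import OpOverload

from executorch.backends.cadence.aot.quantizer.patterns import (
    PartitionAnchors,
    QuantizationPattern,
)


class MulPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.mul.Tensor]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        mul_node = fused_partition[0].nodes[-1]
        # Anchors describe what may be quantized; the node itself is returned so
        # the fusion pass can rewire it even when no quantize node follows it.
        return (
            PartitionAnchors(
                inputs=[(mul_node, 0), (mul_node, 1)],
                weights=[],
                biases=[],
                output=[(mul_node,)],
            ),
            mul_node,
        )

    def replacement_op(self) -> OpOverload:
        # Illustrative placeholder; a real pattern would return its fused kernel op.
        return torch.ops.aten.mul.Tensor
```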

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
             if not no_outside_users(fused_partition):
                 continue

-            anchors = self.pattern.get_anchors(model, fused_partition)
+            anchors, _ = self.pattern.get_anchors(model, fused_partition)
             if not anchors or anchors.empty:
                 continue
             if is_annotated(
