
Commit 792d571

Marco Giordano authored and facebook-github-bot committed
Adding mixed quantization support (pytorch#14134)
Summary:

# Context
This diff adds support for mixed quantization operators in ExecuTorch: weights and biases can now be quantized while inputs and activations are kept in floating point.

# In this diff
1. Op nodes are returned from each pattern match.
2. Dequantize nodes are bypassed if not needed in the final graph.

Reviewed By: skrtskrtfb

Differential Revision: D81519735
1 parent d0f486a commit 792d571
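In mixed quantization, only the parameter tensors carry quantize/dequantize pairs; once a pattern is fused, the dequantize feeding the op can be bypassed because the fused kernel consumes the quantized weight directly, and no quantize node follows the op since activations stay in floating point. A schematic sketch of the rewrite (hypothetical node names, not output from this commit):

# Before fusion (weight-only quantization; input x and output stay fp32):
#     w_q  = quantize_per_tensor(w, ...)
#     w_dq = dequantize_per_tensor(w_q, ...)
#     out  = aten.linear(x, w_dq, b)
#
# After fusion (the dequantize is bypassed and later removed as dead code;
# with no quantize node after the op, the fused op replaces the linear itself):
#     out = cadence.quantized_linear(x, w_q, b, ...)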

File tree

3 files changed (+59, −53 lines)


backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 19 additions & 13 deletions
@@ -471,7 +471,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                 pattern.partition_types(),
             )
             for fused_partition in fused_partitions:
-                anchors = pattern.get_anchors(graph_module, fused_partition)
+                anchors, op_node = pattern.get_anchors(graph_module, fused_partition)
                 if not anchors or anchors.empty:
                     continue
                 if any(self.is_fused(p.nodes) for p in fused_partition):
@@ -512,13 +512,10 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                 bias_inputs = [node.args[0] for node in dequants_biases]
                 other_inputs = [node.args[idx] for node, idx in anchors.others]
 
-                # The node is the first index of the list and first of the tuple
-                anchor_output_node = anchors.output[0][0]
+                assert op_node is not None, "op_node is None"
+                quant_node = list(op_node.users.keys())[0]
 
-                assert len(anchor_output_node.users) == 1
-                quant_node = list(anchor_output_node.users.keys())[0]
-
-                with graph_module.graph.inserting_after(anchor_output_node):
+                with graph_module.graph.inserting_after(op_node):
                     args = tuple(
                         inputs_inputs + weights_inputs + other_inputs + bias_inputs
                     )
@@ -532,7 +529,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                         )
                     elif isinstance(pattern, CatPattern):
                         args, kwargs = get_args_and_kwargs_cat(
-                            inputs_inputs, other_inputs, anchor_output_node
+                            inputs_inputs, other_inputs, op_node
                         )
                     elif isinstance(pattern, ConvReluPatterns):
                         # For ConvReLU, we are fusing Conv+ReLU
@@ -563,7 +560,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                             dequants_weights,
                             bias_inputs,
                             quant_node,
-                            anchor_output_node,
+                            op_node,
                         )
                     elif isinstance(pattern, LinearPattern):
                         args, kwargs = get_args_and_kwargs_linear(
@@ -618,20 +615,29 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                             inputs_inputs,
                             dequants_inputs,
                             quant_node,
-                            anchor_output_node,
+                            op_node,
                         )
+
                     fused = graph_module.graph.call_function(
                         pattern.replacement_op(),
                         args,
                         kwargs,
                     )
-                    fused.meta = quant_node.meta
-                    quant_node.replace_all_uses_with(fused)
+
+                    if len(anchors.output) > 0:
+                        fused.meta = quant_node.meta
+                        quant_node.replace_all_uses_with(fused)
+                    else:
+                        fused.meta = op_node.meta
+                        op_node.replace_all_uses_with(fused)
+                        if op_node.op == "output":
+                            _ = graph_module.graph.output((fused,))
 
         legalize_graph(graph_module)
         graph_module.graph.eliminate_dead_code()
-        # pyre-fixme[7]: Incompatible return type
         graph_module.recompile()
+        return PassResult(graph_module, True)
+
 
     @classmethod
     # pyre-ignore[2]: Parameter `nodes` has no type specified
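The graph-rewrite idiom this pass relies on — create the fused node under inserting_after, copy meta, then replace_all_uses_with and let dead-code elimination drop the original — can be exercised on a toy torch.fx graph. A standalone sketch (plain PyTorch, not ExecuTorch code):

import operator

import torch
import torch.fx as fx


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x + 1) * 2


gm = fx.symbolic_trace(Toy())

# Swap the Python-level add for its aten overload, splicing the new node in
# right after the old one, as fusion_pass.py does with inserting_after(op_node).
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is operator.add:
        with gm.graph.inserting_after(node):
            fused = gm.graph.call_function(torch.ops.aten.add.Tensor, node.args)
        fused.meta = node.meta  # preserve the replaced node's metadata
        node.replace_all_uses_with(fused)

gm.graph.eliminate_dead_code()  # the original add now has no users
gm.recompile()
assert torch.equal(gm(torch.ones(2)), torch.full((2,), 4.0))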

backends/cadence/aot/quantizer/patterns.py

Lines changed: 39 additions & 39 deletions
@@ -67,7 +67,7 @@ def partition_types(self) -> list[OpOverload]:
     @abstractmethod
     def get_anchors(
         self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> Optional[PartitionAnchors]:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         pass
 
     @abstractmethod
@@ -85,7 +85,7 @@ def partition_types(self) -> List[OpOverload]:
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         addmm_node = fused_partition[0].nodes[-1]
 
@@ -101,12 +101,12 @@ def get_anchors(
             qscheme=torch.per_tensor_affine,
         )
 
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(addmm_node, 1)],
             weights=[(addmm_node, 2)],
             biases=[(addmm_node, 0, bias_qspec)],
             output=[(addmm_node,)],
-        )
+        ), addmm_node)
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default
@@ -118,7 +118,7 @@ def partition_types(self) -> List[OpOverload]:
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         add_node = fused_partition[0].nodes[-1]
 
@@ -129,16 +129,16 @@ def get_anchors(
             add_node.args[1], fx.Node
         )
         if not is_tensor_add or len(add_node.kwargs) > 0:
-            return PartitionAnchors(
+            return (PartitionAnchors(
                 empty=True,
-            )
+            ), add_node)
 
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(add_node, 0), (add_node, 1)],
             weights=[],
             biases=[],
             output=[(add_node,)],
-        )
+        ), add_node)
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_add.default
@@ -150,16 +150,16 @@ def partition_types(self) -> List[OpOverload]:
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         bmm_node = fused_partition[0].nodes[-1]
 
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(bmm_node, 0), (bmm_node, 1)],
             weights=[],
             biases=[],
             output=[(bmm_node,)],
-        )
+        ), bmm_node)
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
@@ -171,7 +171,7 @@ def partition_types(self) -> List[OpOverload]:
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         cat_node = fused_partition[0].nodes[-1]
 
@@ -198,14 +198,14 @@ def get_anchors(
             )
         )
 
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=args,
             weights=[],
             biases=[],
             output=[
                 (cat_node, SharedQuantizationSpec((cat_node.args[0][0], cat_node)))
             ],
-        )
+        ), cat_node)
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.aten.cat.default
@@ -217,7 +217,7 @@ def partition_types(self) -> List[OpOverload]:
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         conv1d_node = fused_partition[0].nodes[-1]
 
@@ -238,13 +238,13 @@ def get_anchors(
         if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
             bias = [(conv1d_node, 2, bias_qspec)]
 
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(conv1d_node, 0)],
             weights=[(conv1d_node, 1)],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(conv1d_node,)],
-        )
+        ), conv1d_node)
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv2d_nchw.default
@@ -256,7 +256,7 @@ def partition_types(self) -> List[OpOverload]:
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         conv2d_node = fused_partition[0].nodes[-1]
 
@@ -277,13 +277,13 @@ def get_anchors(
         if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
             bias = [(conv2d_node, 2, bias_qspec)]
 
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(conv2d_node, 0)],
             weights=[(conv2d_node, 1)],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(conv2d_node,)],
-        )
+        ), conv2d_node)
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv2d_nchw.default
@@ -295,7 +295,7 @@ def partition_types(self) -> List[OpOverload]:
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         layer_norm_node = fused_partition[0].nodes[-1]
 
@@ -311,14 +311,14 @@ def get_anchors(
 
         # Weights are used in quantized mode by our kernel, so they are
         # passed in as others here along with the normalized shape.
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(layer_norm_node, 0)],
             weights=[],
             biases=[],
             # Ordering: normalized_shape, weights, bias
             others=others,
             output=[(layer_norm_node,)],
-        )
+        ), layer_norm_node)
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_layer_norm.default
@@ -330,7 +330,7 @@ def partition_types(self) -> List[OpOverload]:
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         linear_node = fused_partition[0].nodes[-1]
 
@@ -351,13 +351,13 @@ def get_anchors(
         if len(linear_node.args) > 2:
             bias = [(linear_node, 2, bias_qspec)]
 
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(linear_node, 0)],
             weights=[(linear_node, 1)],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(linear_node,)],
-        )
+        ), linear_node)
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default
@@ -369,16 +369,16 @@ def partition_types(self) -> List[OpOverload]:
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         matmul_node = fused_partition[0].nodes[-1]
 
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(matmul_node, 0), (matmul_node, 1)],
             weights=[],
             biases=[],
             output=[(matmul_node,)],
-        )
+        ), matmul_node)
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
@@ -392,16 +392,16 @@ def partition_types(self) -> List[OpOverload]:
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         relu_node = fused_partition[0].nodes[-1]
 
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(relu_node, 0)],
             weights=[],
             biases=[],
             output=[(relu_node,)],
-        )
+        ), relu_node)
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_relu.default
@@ -427,7 +427,7 @@ def partition_types(self) -> List[OpOverload]:
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # The first node should be conv, the second should be relu
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         conv_node = fused_partition[0].nodes[-1]  # Second to last node
@@ -451,13 +451,13 @@ def get_anchors(
         if len(conv_node.args) > 2 and conv_node.args[2] is not None:
             bias = [(conv_node, 2, bias_qspec)]
 
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(conv_node, 0)],
             weights=[(conv_node, 1)],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(relu_node,)],  # Output is from the relu node
-        )
+        ), relu_node)
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv2d_nchw.default
@@ -494,16 +494,16 @@ def partition_types(self) -> List[OpOverload]:
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         softmax_node = fused_partition[0].nodes[-1]
 
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(softmax_node, 0)],
             weights=[],
             biases=[],
             output=[(softmax_node,)],
-        )
+        ), softmax_node)
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_softmax.default
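Every pattern now returns its anchors together with the node the fusion pass should anchor on. A minimal sketch of a hypothetical weight-only pattern under the new contract (assuming the QuantizationPattern base class and the PartitionAnchors fields defined in this file; the empty output list is what routes the fusion pass to the mixed-quantization branch):

from typing import List, Tuple

import torch
import torch.fx as fx
from torch._ops import OpOverload

# Assumed to be importable from this module:
# from executorch.backends.cadence.aot.quantizer.patterns import (
#     PartitionAnchors, QuantizationPattern,
# )


class LinearWeightOnlyPattern(QuantizationPattern):  # hypothetical pattern
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        linear_node = fused_partition[0].nodes[-1]
        # output=[] signals that no quantize node follows the op: activations
        # stay in floating point, so the fused op will replace linear_node.
        return (
            PartitionAnchors(
                inputs=[],
                weights=[(linear_node, 1)],
                biases=[],
                output=[],
            ),
            linear_node,
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_linear.default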

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
             if not no_outside_users(fused_partition):
                 continue
 
-            anchors = self.pattern.get_anchors(model, fused_partition)
+            anchors, _ = self.pattern.get_anchors(model, fused_partition)
             if not anchors or anchors.empty:
                 continue
             if is_annotated(
