Commit 09f5beb

Add support for fusing Conv+ReLU
Differential Revision: D79381533
Pull Request resolved: #14229
Parent commit: 6789f75

File tree: 4 files changed, +146 −7 lines

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 39 additions & 7 deletions
@@ -15,14 +15,19 @@
     BmmPattern,
     CatPattern,
     Conv1dPattern,
+    Conv1dReluPattern0,
+    Conv1dReluPattern1,
     Conv2dPattern,
+    Conv2dReluPattern0,
+    Conv2dReluPattern1,
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
     ReluPattern0,
     ReluPattern1,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
+    check_out_zero_point_is_min_range,
     create_zero_bias_int32,
     find_sequential_partitions_aten,
     get_conv_args,
@@ -41,6 +46,13 @@

 # Use this part for patterns with multiple aten ops
 ReluPatterns = (ReluPattern0, ReluPattern1)
+ConvPatterns = (Conv1dPattern, Conv2dPattern)
+ConvReluPatterns = (
+    Conv1dReluPattern0,
+    Conv1dReluPattern1,
+    Conv2dReluPattern0,
+    Conv2dReluPattern1,
+)


 def get_args_and_kwargs_add(
@@ -432,12 +444,12 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
             other_inputs = [node.args[idx] for node, idx in anchors.others]

             # The node is the first index of the list and first of the tuple
-            op_node = anchors.output[0][0]
+            anchor_output_node = anchors.output[0][0]

-            assert len(op_node.users) == 1
-            quant_node = list(op_node.users.keys())[0]
+            assert len(anchor_output_node.users) == 1
+            quant_node = list(anchor_output_node.users.keys())[0]

-            with graph_module.graph.inserting_after(op_node):
+            with graph_module.graph.inserting_after(anchor_output_node):
                 args = tuple(
                     inputs_inputs + weights_inputs + other_inputs + bias_inputs
                 )
@@ -451,9 +463,29 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     )
                 elif isinstance(pattern, CatPattern):
                     args, kwargs = get_args_and_kwargs_cat(
-                        inputs_inputs, other_inputs, op_node
+                        inputs_inputs, other_inputs, anchor_output_node
+                    )
+                elif isinstance(pattern, ConvReluPatterns):
+                    # For ConvReLU, we are fusing Conv+ReLU.
+                    # This means that the op we want to get
+                    # the replacement args and kwargs for is the
+                    # *conv* op, which is the anchor input, NOT
+                    # the anchor output (which is the ReLU).
+                    check_out_zero_point_is_min_range(
+                        quant_node.args[2], quant_node.args[5]
+                    )
+                    anchor_input_node = anchors.inputs[0][0]
+                    args, kwargs = get_args_and_kwargs_conv(
+                        graph_module,
+                        inputs_inputs,
+                        dequants_inputs,
+                        weights_inputs,
+                        dequants_weights,
+                        bias_inputs,
+                        quant_node,
+                        anchor_input_node,
                     )
-                elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
+                elif isinstance(pattern, ConvPatterns):
                     args, kwargs = get_args_and_kwargs_conv(
                         graph_module,
                         inputs_inputs,
@@ -462,7 +494,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         dequants_weights,
                         bias_inputs,
                         quant_node,
-                        op_node,
+                        anchor_output_node,
                     )
                 elif isinstance(pattern, LinearPattern):
                     args, kwargs = get_args_and_kwargs_linear(
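Note on the new ConvReluPatterns branch: before fusing, it passes the quant node's zero point and dtype to check_out_zero_point_is_min_range. A plausible reading of that check, as a minimal sketch using standard affine-quantization arithmetic (not code from this diff): a post-ReLU activation range starts at 0.0, so an observer spanning [0, x_max] places the zero point at the bottom of the quantized range, which is exactly the value the helper verifies.

# Illustrative arithmetic only (not part of this diff). With the affine scheme
#   q = clamp(round(x / scale) + zero_point, qmin, qmax)
# a post-ReLU range [0.0, x_max] maps 0.0 to qmin, so zero_point == qmin.
qmin, qmax = -128, 127                   # torch.int8 range
x_min, x_max = 0.0, 6.0                  # example post-ReLU activation range
scale = (x_max - x_min) / (qmax - qmin)
zero_point = qmin - round(x_min / scale)
assert zero_point == -128                # the value the check expects for int8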

backends/cadence/aot/quantizer/patterns.py

Lines changed: 68 additions & 0 deletions
@@ -417,3 +417,71 @@ def partition_types(self) -> List[OpOverload]:
 class ReluPattern1(ReluBasePattern):
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten.relu_.default]
+
+
+# Base class for Conv+ReLU fusion; it can be used with two different relu aten ops.
+class ConvReluBasePattern(QuantizationPattern):
+    @abstractmethod
+    def partition_types(self) -> List[OpOverload]:
+        pass
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # The first partition should be conv, the second should be relu.
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        conv_node = fused_partition[0].nodes[-1]  # Last node of the conv partition
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        relu_node = fused_partition[1].nodes[-1]  # Last node of the relu partition
+
+        bias_qspec = DerivedQuantizationSpec(
+            derived_from=[
+                (conv_node.args[0], conv_node),
+                (conv_node.args[1], conv_node),
+            ],
+            derive_qparams_fn=get_bias_qparams,
+            dtype=torch.int32,
+            quant_min=-(2**31),
+            quant_max=2**31 - 1,
+            qscheme=torch.per_tensor_affine,
+        )
+
+        # Keep bias empty if not supplied
+        bias = []
+        if len(conv_node.args) > 2 and conv_node.args[2] is not None:
+            bias = [(conv_node, 2, bias_qspec)]
+
+        return PartitionAnchors(
+            inputs=[(conv_node, 0)],
+            weights=[(conv_node, 1)],
+            # pyre-fixme[6]: Incompatible parameter type
+            biases=bias,
+            output=[(relu_node,)],  # Output is taken from the relu node
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_conv_nchw.default
+
+
+# Conv1d + regular relu op fusion
+class Conv1dReluPattern0(ConvReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default, torch.ops.aten.relu.default]
+
+
+# Conv1d + in-place relu op fusion
+class Conv1dReluPattern1(ConvReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default, torch.ops.aten.relu_.default]
+
+
+# Conv2d + regular relu op fusion
+class Conv2dReluPattern0(ConvReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default]
+
+
+# Conv2d + in-place relu op fusion
+class Conv2dReluPattern1(ConvReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default]
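For reference, a small sketch (not part of this diff) of the kind of module these patterns target. torch.relu traces to aten.relu.default, which is what Conv2dReluPattern0 declares; an in-place x.relu_() would trace to aten.relu_.default and match Conv2dReluPattern1. The exact op set in the exported graph depends on the export and decomposition mode the Cadence flow uses, so treat the printed output as illustrative.

import torch
from torch import nn


class ConvReluBlock(nn.Module):
    """Toy module whose traced graph pairs a conv op with a relu op."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.relu -> aten.relu.default (Pattern0);
        # x.relu_() would be aten.relu_.default (Pattern1).
        return torch.relu(self.conv(x))


ep = torch.export.export(ConvReluBlock(), (torch.randn(1, 3, 16, 16),))
print([n.target for n in ep.graph.nodes if n.op == "call_function"])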

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 23 additions & 0 deletions
@@ -16,7 +16,11 @@
     BmmPattern,
     CatPattern,
     Conv1dPattern,
+    Conv1dReluPattern0,
+    Conv1dReluPattern1,
     Conv2dPattern,
+    Conv2dReluPattern0,
+    Conv2dReluPattern1,
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
@@ -260,3 +264,22 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
         quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
         super().__init__(quantizers)
+
+
+class CadenceFusedConvReluQuantizer(CadenceQuantizer):
+    """
+    Quantizer using fused conv+relu patterns, also including add and cat.
+    """
+
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = []
+        # Order matters here: apply the "fused" patterns first.
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), qconfig_A8W8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), qconfig_A8W8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), qconfig_A8W8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), qconfig_A8W8sym))
+        quantizers = quantizers + get_cadence_default_quantizers()
+        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
+        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
+        super().__init__(quantizers)
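A hedged sketch of how the new CadenceFusedConvReluQuantizer could be dropped into the standard PT2E prepare/convert flow. The export entry point and calibration details vary across PyTorch/ExecuTorch versions, so the calls below are assumptions for illustration rather than the canonical Cadence recipe.

import torch
from torch import nn
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceFusedConvReluQuantizer,
)

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
example_inputs = (torch.randn(1, 3, 32, 32),)

# Export, annotate with the fused conv+relu quantizer, calibrate, convert.
exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, CadenceFusedConvReluQuantizer())
prepared(*example_inputs)           # one calibration pass with representative data
quantized = convert_pt2e(prepared)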

backends/cadence/aot/quantizer/utils.py

Lines changed: 16 additions & 0 deletions
@@ -234,3 +234,19 @@ def find_sequential_partitions_aten(
         if _partitions_sequential(candidate):
             fused_partitions.append(candidate)
     return fused_partitions
+
+
+def check_out_zero_point_is_min_range(
+    out_zero_point: int,
+    out_dtype: torch.dtype,
+) -> bool:
+    """
+    Checks whether out_zero_point equals the minimum of the quantized dtype's range.
+    """
+    if out_dtype == torch.int8:
+        return out_zero_point == -128
+    elif out_dtype == torch.int16:
+        return out_zero_point == -32768
+    elif out_dtype in (torch.uint8, torch.uint16):
+        return out_zero_point == 0
+    return False
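A quick sanity check of the helper's behaviour, assuming the import path mirrors the file location above:

import torch
from executorch.backends.cadence.aot.quantizer.utils import (
    check_out_zero_point_is_min_range,
)

assert check_out_zero_point_is_min_range(-128, torch.int8)   # int8 range is [-128, 127]
assert check_out_zero_point_is_min_range(0, torch.uint8)     # uint8 range is [0, 255]
assert not check_out_zero_point_is_min_range(0, torch.int8)  # 0 is not int8's minimum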
