
Commit b918df8

mcremon-meta authored and facebook-github-bot committed
Add support for fusing Conv+ReLU (#14229)
Summary: This diff adds support for *implicitly* fusing Conv+ReLU (both Conv1d and Conv2d variants). We add new patterns that capture this sequence of ops and ensure the subgraph is treated as a single node during qparam calculation. Then, during fusion, we replace the subgraph with just the conv and drop the ReLU.

Reviewed By: ethansfng, hsharma35, ivayloen

Differential Revision: D79381533
1 parent 233063c commit b918df8
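To make the target concrete, here is a minimal illustrative sketch (not part of the diff; the module, shapes, and the expectation about the exported op set are assumptions and can vary with PyTorch version): an eager Conv2d followed by ReLU, exported so that the aten.conv2d.default -> aten.relu.default sequence the new patterns capture is visible.

import torch

class ConvRelu(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv followed by ReLU: the subgraph the new patterns treat as one node
        return torch.relu(self.conv(x))

ep = torch.export.export(ConvRelu().eval(), (torch.randn(1, 3, 32, 32),))
print(ep.graph_module.graph)  # expect conv2d.default feeding relu.default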

File tree: 4 files changed, +146 -7 lines changed


backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 39 additions & 7 deletions
@@ -15,14 +15,19 @@
     BmmPattern,
     CatPattern,
     Conv1dPattern,
+    Conv1dReluPattern0,
+    Conv1dReluPattern1,
     Conv2dPattern,
+    Conv2dReluPattern0,
+    Conv2dReluPattern1,
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
     ReluPattern0,
     ReluPattern1,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
+    check_out_zero_point_is_min_range,
     create_zero_bias_int32,
     find_sequential_partitions_aten,
     get_conv_args,
@@ -41,6 +46,13 @@
 
 # Use this part for patterns with multiple aten ops
 ReluPatterns = (ReluPattern0, ReluPattern1)
+ConvPatterns = (Conv1dPattern, Conv2dPattern)
+ConvReluPatterns = (
+    Conv1dReluPattern0,
+    Conv1dReluPattern1,
+    Conv2dReluPattern0,
+    Conv2dReluPattern1,
+)
 
 
 def get_args_and_kwargs_add(
@@ -432,12 +444,12 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
             other_inputs = [node.args[idx] for node, idx in anchors.others]
 
             # The node is the first index of the list and first of the tuple
-            op_node = anchors.output[0][0]
+            anchor_output_node = anchors.output[0][0]
 
-            assert len(op_node.users) == 1
-            quant_node = list(op_node.users.keys())[0]
+            assert len(anchor_output_node.users) == 1
+            quant_node = list(anchor_output_node.users.keys())[0]
 
-            with graph_module.graph.inserting_after(op_node):
+            with graph_module.graph.inserting_after(anchor_output_node):
                 args = tuple(
                     inputs_inputs + weights_inputs + other_inputs + bias_inputs
                 )
@@ -451,9 +463,29 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     )
                 elif isinstance(pattern, CatPattern):
                     args, kwargs = get_args_and_kwargs_cat(
-                        inputs_inputs, other_inputs, op_node
+                        inputs_inputs, other_inputs, anchor_output_node
+                    )
+                elif isinstance(pattern, ConvReluPatterns):
+                    # For ConvReLU, we are fusing Conv+ReLU
+                    # This means that the op we want to get
+                    # the replacement args and kwargs for is the
+                    # *conv* op, which is the anchor input, NOT
+                    # the anchor output (which is the ReLU)
+                    check_out_zero_point_is_min_range(
+                        quant_node.args[2], quant_node.args[5]
+                    )
+                    anchor_input_node = anchors.inputs[0][0]
+                    args, kwargs = get_args_and_kwargs_conv(
+                        graph_module,
+                        inputs_inputs,
+                        dequants_inputs,
+                        weights_inputs,
+                        dequants_weights,
+                        bias_inputs,
+                        quant_node,
+                        anchor_input_node,
                     )
-                elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
+                elif isinstance(pattern, ConvPatterns):
                     args, kwargs = get_args_and_kwargs_conv(
                         graph_module,
                         inputs_inputs,
@@ -462,7 +494,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         dequants_weights,
                         bias_inputs,
                         quant_node,
-                        op_node,
+                        anchor_output_node,
                     )
                 elif isinstance(pattern, LinearPattern):
                     args, kwargs = get_args_and_kwargs_linear(
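To make the anchor bookkeeping in the ConvReluPatterns branch concrete, here is a plain-Python stand-in (illustrative only; strings stand in for fx nodes) mirroring the PartitionAnchors layout built by ConvReluBasePattern.get_anchors() in patterns.py below: the conv node is the anchor input and the relu node is the anchor output, so the fused replacement rebuilds its args from the conv while the quant node after the relu supplies the output qparams.

# Strings stand in for fx nodes; the tuples mirror the PartitionAnchors fields.
conv_node = "aten.conv2d.default"
relu_node = "aten.relu.default"
anchors_inputs = [(conv_node, 0)]   # mirrors inputs=[(conv_node, 0)]
anchors_output = [(relu_node,)]     # mirrors output=[(relu_node,)]

anchor_input_node = anchors_inputs[0][0]    # the conv op: source of replacement args
anchor_output_node = anchors_output[0][0]   # the relu op: dropped after fusion
assert anchor_input_node == conv_node
assert anchor_output_node == relu_node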

backends/cadence/aot/quantizer/patterns.py

Lines changed: 68 additions & 0 deletions
@@ -417,3 +417,71 @@ def partition_types(self) -> List[OpOverload]:
 class ReluPattern1(ReluBasePattern):
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten.relu_.default]
+
+
+# This is a base class for Conv+ReLU fusion, since it can be used with two different relu aten ops
+class ConvReluBasePattern(QuantizationPattern):
+    @abstractmethod
+    def partition_types(self) -> List[OpOverload]:
+        pass
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # The first partition should end in the conv node, the second in the relu node
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        conv_node = fused_partition[0].nodes[-1]  # Last node of the conv partition
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        relu_node = fused_partition[1].nodes[-1]  # Last node of the relu partition
+
+        bias_qspec = DerivedQuantizationSpec(
+            derived_from=[
+                (conv_node.args[0], conv_node),
+                (conv_node.args[1], conv_node),
+            ],
+            derive_qparams_fn=get_bias_qparams,
+            dtype=torch.int32,
+            quant_min=-(2**31),
+            quant_max=2**31 - 1,
+            qscheme=torch.per_tensor_affine,
+        )
+
+        # Keep bias empty if not supplied
+        bias = []
+        if len(conv_node.args) > 2 and conv_node.args[2] is not None:
+            bias = [(conv_node, 2, bias_qspec)]
+
+        return PartitionAnchors(
+            inputs=[(conv_node, 0)],
+            weights=[(conv_node, 1)],
+            # pyre-fixme[6]: Incompatible parameter type
+            biases=bias,
+            output=[(relu_node,)],  # Output is taken from the relu node
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_conv_nchw.default
+
+
+# Conv1d + regular relu op fusion
+class Conv1dReluPattern0(ConvReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default, torch.ops.aten.relu.default]
+
+
+# Conv1d + in-place relu op (relu_) fusion
+class Conv1dReluPattern1(ConvReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default, torch.ops.aten.relu_.default]
+
+
+# Conv2d + regular relu op fusion
+class Conv2dReluPattern0(ConvReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default]
+
+
+# Conv2d + in-place relu op (relu_) fusion
+class Conv2dReluPattern1(ConvReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default]
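A side note on the bias handling above: the DerivedQuantizationSpec derives the int32 bias qparams from the conv input and weight via get_bias_qparams. A common convention, assumed here purely for illustration (the actual derivation lives in get_bias_qparams), is bias_scale = input_scale * weight_scale with zero point 0, so the quantized bias adds directly onto the conv's int32 accumulator.

# Numeric sketch of int32 bias quantization under the assumed convention
# bias_scale = input_scale * weight_scale (hypothetical example values).
import torch

input_scale, weight_scale = 0.02, 0.005
bias_scale = input_scale * weight_scale
bias_fp = torch.tensor([0.013, -0.0072])
bias_q = torch.clamp(
    torch.round(bias_fp / bias_scale), min=-(2**31), max=2**31 - 1
).to(torch.int32)
print(bias_q)  # bias expressed in accumulator units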

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 23 additions & 0 deletions
@@ -16,7 +16,11 @@
     BmmPattern,
     CatPattern,
     Conv1dPattern,
+    Conv1dReluPattern0,
+    Conv1dReluPattern1,
     Conv2dPattern,
+    Conv2dReluPattern0,
+    Conv2dReluPattern1,
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
@@ -260,3 +264,22 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
         quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
         super().__init__(quantizers)
+
+
+class CadenceFusedConvReluQuantizer(CadenceQuantizer):
+    """
+    Quantizer using fused conv+relu patterns, including add and cat.
+    """
+
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = []
+        # Order matters here: register the "fused" patterns first
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), qconfig_A8W8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), qconfig_A8W8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), qconfig_A8W8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), qconfig_A8W8sym))
+        quantizers = quantizers + get_cadence_default_quantizers()
+        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
+        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
+        super().__init__(quantizers)
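A hedged end-to-end sketch of how the new quantizer could be exercised through the generic PT2E entry points (assumptions: prepare_pt2e/convert_pt2e on an exported module is the standard PyTorch 2 export quantization flow, not necessarily the exact Cadence AoT driver; the Conv1d toy model and shapes are made up):

import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceFusedConvReluQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

class ConvRelu1d(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv1d(3, 8, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x))

example_inputs = (torch.randn(1, 3, 32),)
gm = torch.export.export(ConvRelu1d().eval(), example_inputs).module()

quantizer = CadenceFusedConvReluQuantizer()
prepared = prepare_pt2e(gm, quantizer)  # Conv1dReluPattern* annotates conv+relu as one unit
prepared(*example_inputs)               # calibrate observers
converted = convert_pt2e(prepared)      # conv and relu now share one set of output qparams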

backends/cadence/aot/quantizer/utils.py

Lines changed: 16 additions & 0 deletions
@@ -234,3 +234,19 @@ def find_sequential_partitions_aten(
         if _partitions_sequential(candidate):
             fused_partitions.append(candidate)
     return fused_partitions
+
+
+def check_out_zero_point_is_min_range(
+    out_zero_point: int,
+    out_dtype: torch.dtype,
+) -> bool:
+    """
+    Checks if the out_zero_point is the minimum of the quantized dtype's range.
+    """
+    if out_dtype == torch.int8:
+        return out_zero_point == -128
+    elif out_dtype == torch.int16:
+        return out_zero_point == -32768
+    elif out_dtype in (torch.uint8, torch.uint16):
+        return out_zero_point == 0
+    return False
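A short numeric sketch of why this check matters for dropping the ReLU (general affine-quantization math, not code from the diff): with the output zero point at the dtype minimum, every representable quantized value dequantizes to a non-negative number, so the quantized conv output already encodes ReLU's clamp at zero.

import torch

scale, zero_point = 0.05, -128                    # int8 example: qmin == -128
q = torch.arange(-128, 128, dtype=torch.float32)  # all representable int8 values
deq = scale * (q - zero_point)                    # affine dequantization
assert torch.all(deq >= 0)                        # no representable output is negative
# For uint8/uint16 the minimum is 0, so zero_point == 0 gives the same guarantee.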
