
Commit 4884bc5

mcr229 authored and facebook-github-bot committed
[Quantizer][XNNPACK] Fix ReLU fusion when conv/linear has > 1 user (pytorch#140846)
Summary:
X-link: pytorch/executorch#6894

Fixes a bug in the quantizer where Conv + ReLU was fused even when the preceding conv has more than one user. Conv and ReLU cannot be fused in this case because the result of the conv must also be used elsewhere in the graph. The XNNPACK delegate handles the unfused ReLU naturally by inserting a clamp node.

Test Plan: CI

Reviewed By: digantdesai

Differential Revision: D65989599
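For context, a minimal sketch of the graph shape that triggered the bug (not from the commit; the module and names are hypothetical): the conv output has a second user besides relu, so fusing Conv + ReLU would destroy a value the graph still needs.

import torch

class ConvReluWithExtraUser(torch.nn.Module):
    # The conv output y has two users (relu and the add below), so the
    # quantizer must annotate conv and relu separately, not as a fused unit.
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)           # user 1: relu; user 2: the add below
        return torch.relu(y) + y   # fusing conv+relu would lose the raw y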
1 parent: 081c168

File tree

1 file changed: +21 −0

torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py

Lines changed: 21 additions & 0 deletions
@@ -247,6 +247,10 @@ def _annotate_linear_relu(
             continue
 
         linear_node = maybe_linear_node
+        if len(linear_node.users) > 1:
+            # if linear node has multiple users, then it can't be fused with relu
+            continue
+
         input_qspec_map = {}
         input_act = linear_node.args[0]
         assert isinstance(input_act, Node)
@@ -351,6 +355,11 @@ def _do_annotate_conv_relu(
             continue
         conv_node = maybe_conv_node
 
+        if len(conv_node.users) > 1:
+            # relu shouldn't be fused into conv if there are other users
+            # of the convolution
+            continue
+
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
@@ -738,6 +747,12 @@ def _annotate_add_relu(
             continue
 
         add_node = maybe_add
+
+        if len(add_node.users) > 1:
+            # add can't be fused with ReLU if the result of add is being used
+            # elsewhere in the graph
+            continue
+
         partition = [relu_node, add_node]
 
         if _is_annotated(partition):
@@ -860,6 +875,11 @@ def _annotate_mul_relu(
             continue
 
         mul_node = maybe_mul
+        if len(mul_node.users) > 1:
+            # mul can't be fused with ReLU if the result of mul is being used
+            # elsewhere in the graph
+            continue
+
         partition = [relu_node, mul_node]
 
         if _is_annotated(partition):
@@ -1003,6 +1023,7 @@ def _annotate_cat(
 
 def _is_share_obs_or_fq_op(op: Callable) -> bool:
     return op in [
+        torch.ops.aten.relu.default,
         torch.ops.aten.hardtanh.default,
         torch.ops.aten.hardtanh_.default,
         torch.ops.aten.max_pool2d.default,
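Adding torch.ops.aten.relu.default here lets a standalone (now unfused) relu share its observer or fake-quantize with its input. As a rough illustration (not part of the commit; the capture API varies across PyTorch versions), a PT2E flow that exercises this quantizer on the hypothetical module above might look like:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = ConvReluWithExtraUser().eval()
example_inputs = (torch.randn(1, 3, 16, 16),)

# Capture the model for PT2E quantization (export_for_training in recent releases).
exported = torch.export.export_for_training(model, example_inputs).module()

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(exported, quantizer)  # annotate graph, insert observers
prepared(*example_inputs)                     # calibrate
quantized = convert_pt2e(prepared)            # replace observers with q/dq ops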
