Commit aa3cf7c

mcr229 authored and facebook-github-bot committed
Fix ReLU fusion when conv/linear has > 1 user (pytorch#6894)
Summary: X-link: pytorch/pytorch#140846

Fixes a bug in the quantizer: Conv + ReLU was being fused even when the preceding conv has more than one user. Conv and ReLU cannot be fused in that case, because the result of the conv must also be available to its other users. The XNNPACK delegate handles the unfused ReLU naturally by inserting a clamp node for it.

Reviewed By: digantdesai

Differential Revision: D65989599
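For context, the essence of the fix is a user-count guard in the quantizer's Conv + ReLU pattern matching. A minimal sketch, assuming an FX-style graph where each node exposes a .users mapping; the helper name below is hypothetical, not the quantizer's actual API:

    def can_fuse_conv_relu(conv_node, relu_node):
        # If the conv output feeds anything besides this ReLU, its
        # pre-activation value must stay visible to the other users,
        # so the pair must not be annotated as a fused Conv+ReLU pattern.
        if len(conv_node.users) > 1:
            return False
        return next(iter(conv_node.users)) is relu_node

When fusion is skipped, the standalone ReLU still lowers cleanly: elementwise ReLU is just a clamp with a zero lower bound (torch.clamp(x, min=0.0)), which is the clamp node the delegate inserts.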
1 parent 5b4d9bb commit aa3cf7c

File tree

1 file changed: +23 −0 lines

backends/xnnpack/test/ops/conv2d.py

Lines changed: 23 additions & 0 deletions
@@ -394,3 +394,26 @@ def get_inputs(self):
             quant_config=get_symmetric_quantization_config(),
             conv_count=2,
         )
+
+    def test_qs8_conv2d_relu_multi_users(self):
+        class Conv2dReluMultiUsers(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(1, 1, 1)
+                self.conv2 = torch.nn.Conv2d(1, 64, 1)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                conv_default = self.conv1(x)
+                y = self.relu(conv_default)
+                conv_default_2 = self.conv2(y)
+                return conv_default + conv_default_2
+
+            def get_inputs(self):
+                return (torch.randn(1, 1, 1, 1),)
+
+        self._test(
+            Conv2dReluMultiUsers(),
+            quant_config=get_symmetric_quantization_config(),
+            conv_count=2,
+        )
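The module deliberately gives conv1 two users: its output feeds both self.relu and the final add in forward, so the quantizer must leave the Conv + ReLU pair unfused. To run just this test locally (assuming a standard ExecuTorch checkout and that the enclosing test class is named TestConv2d, as elsewhere in this file):

    python -m unittest backends.xnnpack.test.ops.conv2d.TestConv2d.test_qs8_conv2d_relu_multi_users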
