diff --git a/backends/xnnpack/test/ops/conv2d.py b/backends/xnnpack/test/ops/conv2d.py index 95b22bb3f8a..c98743ebe8b 100644 --- a/backends/xnnpack/test/ops/conv2d.py +++ b/backends/xnnpack/test/ops/conv2d.py @@ -394,3 +394,26 @@ def get_inputs(self): quant_config=get_symmetric_quantization_config(), conv_count=2, ) + + def test_qs8_conv2d_relu_multi_users(self): + class Conv2dReluMultiUsers(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + self.conv2 = torch.nn.Conv2d(1, 64, 1) + self.relu = torch.nn.ReLU() + + def forward(self, x): + conv_default = self.conv1(x) + y = self.relu(conv_default) + conv_default_2 = self.conv2(y) + return conv_default + conv_default_2 + + def get_inputs(self): + return (torch.randn(1, 1, 1, 1),) + + self._test( + Conv2dReluMultiUsers(), + quant_config=get_symmetric_quantization_config(), + conv_count=2, + )