
Commit a747e4d

Arm backend: Fix annotation of inplace ReLU (#14540)
The ResNet18 model uses many ReLUs with inplace=True. With the corrected annotation, the numerical accuracy check on ResNet18 passes with a lower atol.
1 parent f198c27 commit a747e4d
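For context, not part of the commit itself: with inplace=True, torch.nn.ReLU is captured as the in-place op torch.ops.aten.relu_.default in the ATen graph that the PT2E quantizer annotates, so annotation patterns that only list torch.ops.aten.relu.default miss it and the ReLU gets quantized separately from the preceding conv. Below is a minimal sketch for inspecting which variant ends up in the captured graph; the ConvReLU module and the use of torch.export.export_for_training here are illustrative assumptions, not code from this commit.

import torch


class ConvReLU(torch.nn.Module):
    # Mirrors the ResNet18 pattern: conv followed by ReLU(inplace=True).
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv(x))


# Capture the pre-autograd ATen graph the way the PT2E quantization flow does.
ep = torch.export.export_for_training(ConvReLU(), (torch.randn(1, 3, 16, 16),))

# The printed graph is expected to show torch.ops.aten.relu_.default (note the
# trailing underscore) rather than aten.relu.default, which is why the annotator
# patterns in the diff below list both variants.
print(ep.module().graph)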

File tree

3 files changed: +63 −3 lines changed

backends/arm/quantizer/quantization_annotator.py
backends/arm/test/models/test_resnet18.py
backends/arm/test/ops/test_relu.py

backends/arm/quantizer/quantization_annotator.py

Lines changed: 11 additions & 2 deletions
@@ -392,7 +392,11 @@ def any_or_hardtanh_min_zero(n: Node):
                 torch.ops.aten.conv2d.padding,
             ],
             [torch.ops.aten.batch_norm.default, F.batch_norm],
-            [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
+            [
+                torch.ops.aten.relu.default,
+                torch.ops.aten.relu_.default,
+                torch.ops.aten.hardtanh.default,
+            ],
         ],
         filter_fn=any_or_hardtanh_min_zero,
     ):
@@ -408,6 +412,7 @@ def any_or_hardtanh_min_zero(n: Node):
         ]
     elif node.target in (
         torch.ops.aten.relu.default,
+        torch.ops.aten.relu_.default,
         torch.ops.aten.hardtanh.default,
     ):
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
@@ -444,7 +449,11 @@ def any_or_hardtanh_min_zero(n: Node):
                 torch.ops.aten.linear.default,
                 torch.ops.aten.conv2d.padding,
             ],
-            [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
+            [
+                torch.ops.aten.relu.default,
+                torch.ops.aten.relu_.default,
+                torch.ops.aten.hardtanh.default,
+            ],
         ],
         any_or_hardtanh_min_zero,
     ):

backends/arm/test/models/test_resnet18.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def test_resnet_tosa_INT(per_channel_quantization):
         exir_op=[],
         use_to_edge_transform_and_lower=True,
         per_channel_quantization=per_channel_quantization,
-        atol=0.5,
+        atol=0.25,
         qtol=1,
     )
     pipeline.run()

backends/arm/test/ops/test_relu.py

Lines changed: 51 additions & 0 deletions
@@ -43,6 +43,28 @@ def forward(self, x):
         return self.relu(x)
 
 
+test_data_conv_relu = {
+    # (test_name, test_data)
+    "4d_randn_inplace=True": (lambda: (torch.randn(1, 64, 96, 96) * 1000, True)),
+    "4d_randn_inplace=False": (lambda: (torch.randn(1, 64, 96, 96) * 1000, False)),
+}
+
+
+class Conv2d_Relu_Add(torch.nn.Module):
+    def __init__(self, inplace: bool = True):
+        super().__init__()
+        self.conv1 = torch.nn.Conv2d(
+            in_channels=64, out_channels=64, kernel_size=7, padding="same"
+        )
+        self.relu = torch.nn.ReLU(inplace=inplace)
+
+    def forward(self, x: torch.Tensor):
+        y = self.conv1(x)
+        z = self.relu(y)
+        out = x + z
+        return out
+
+
 @common.parametrize("test_data", test_data_suite)
 def test_relu_tosa_FP(test_data: torch.Tensor):
     pipeline = TosaPipelineFP[input_t1](
@@ -54,6 +76,35 @@ def test_relu_tosa_FP(test_data: torch.Tensor):
     pipeline.run()
 
 
+# Test the folding of Conv2D with ReLU
+@common.parametrize("test_data", test_data_conv_relu)
+def test_conv_relu_folding_tosa_INT(test_data: torch.Tensor):
+    input_data, inplace = test_data()
+    pipeline = TosaPipelineINT[input_t1](
+        Conv2d_Relu_Add(inplace=inplace),
+        (input_data,),
+        [],
+        [],
+    )
+    # We should have:
+    # 3 quantize_per_tensor nodes: input activation, output of the conv-relu sequence, output of the add
+    # 4 dequantize_per_tensor nodes: into the conv2d input, into the add, output of the conv-relu sequence, before returning
+    # 2 dequantize_per_channel nodes: one for the weights and another one for the bias
+    # In case of incorrect annotation of the ReLU, we get separate Q/DQ pairs around both the conv2d and the ReLU and
+    # therefore more quantize_per_tensor and dequantize_per_tensor nodes.
+    pipeline.add_stage_after(
+        "quantize",
+        pipeline.tester.check_count,
+        {
+            "quantized_decomposed.quantize_per_tensor.default": 3,
+            "torch.ops.quantized_decomposed.dequantize_per_tensor.default": 4,
+            "quantized_decomposed.dequantize_per_channel.default": 2,
+        },
+        suffix="quant_nodes",
+    )
+    pipeline.run()
+
+
 @common.parametrize("test_data", test_data_suite)
 def test_relu_tosa_INT(test_data: torch.Tensor):
     pipeline = TosaPipelineINT[input_t1](
