@@ -43,6 +43,28 @@ def forward(self, x):
43
43
return self .relu (x )
44
44
45
45
46
+ test_data_conv_relu = {
47
+ # (test_name, test_data)
48
+ "4d_randn_inplace=True" : (lambda : (torch .randn (1 , 64 , 96 , 96 ) * 1000 , True )),
49
+ "4d_randn_inplace=False" : (lambda : (torch .randn (1 , 64 , 96 , 96 ) * 1000 , False )),
50
+ }
51
+
52
+
53
+ class Conv2d_Relu_Add (torch .nn .Module ):
54
+ def __init__ (self , inplace : bool = True ):
55
+ super ().__init__ ()
56
+ self .conv1 = torch .nn .Conv2d (
57
+ in_channels = 64 , out_channels = 64 , kernel_size = 7 , padding = "same"
58
+ )
59
+ self .relu = torch .nn .ReLU (inplace = inplace )
60
+
61
+ def forward (self , x : torch .Tensor ):
62
+ y = self .conv1 (x )
63
+ z = self .relu (y )
64
+ out = x + z
65
+ return out
66
+
67
+
46
68
@common .parametrize ("test_data" , test_data_suite )
47
69
def test_relu_tosa_FP (test_data : torch .Tensor ):
48
70
pipeline = TosaPipelineFP [input_t1 ](
@@ -54,6 +76,35 @@ def test_relu_tosa_FP(test_data: torch.Tensor):
54
76
pipeline .run ()
55
77
56
78
79
+ # Test the folding of Conv2D with ReLU
80
+ @common .parametrize ("test_data" , test_data_conv_relu )
81
+ def test_conv_relu_folding_tosa_INT (test_data : torch .Tensor ):
82
+ input_data , inplace = test_data ()
83
+ pipeline = TosaPipelineINT [input_t1 ](
84
+ Conv2d_Relu_Add (inplace = inplace ),
85
+ (input_data ,),
86
+ [],
87
+ [],
88
+ )
89
+ # We should have :
90
+ # 3 quantize_per_tensor nodes: input activation , output of the conv-relu sequence, out of the add
91
+ # 4 dequantize_per_tensor nodes: into the conv2d input, into the add, output of the conv-relu sequence, before returning
92
+ # 2 dequantize_per_channel nodes: one for the weights and another one for the bias
93
+ # In case of incorrect annotation of the ReLU, we get separate Q/DR around both the conv2d and the ReLU and
94
+ # therefore more quantize_per_tensor and dequantize_per_tensor nodes
95
+ pipeline .add_stage_after (
96
+ "quantize" ,
97
+ pipeline .tester .check_count ,
98
+ {
99
+ "quantized_decomposed.quantize_per_tensor.default" : 3 ,
100
+ "torch.ops.quantized_decomposed.dequantize_per_tensor.default" : 4 ,
101
+ "quantized_decomposed.dequantize_per_channel.default" : 2 ,
102
+ },
103
+ suffix = "quant_nodes" ,
104
+ )
105
+ pipeline .run ()
106
+
107
+
57
108
@common .parametrize ("test_data" , test_data_suite )
58
109
def test_relu_tosa_INT (test_data : torch .Tensor ):
59
110
pipeline = TosaPipelineINT [input_t1 ](
0 commit comments