@@ -43,6 +43,28 @@ def forward(self, x):
4343 return self .relu (x )
4444
4545
46+ test_data_conv_relu = {
47+ # (test_name, test_data)
48+ "4d_randn_inplace=True" : (lambda : (torch .randn (1 , 64 , 96 , 96 ) * 1000 , True )),
49+ "4d_randn_inplace=False" : (lambda : (torch .randn (1 , 64 , 96 , 96 ) * 1000 , False )),
50+ }
51+
52+
53+ class Conv2d_Relu_Add (torch .nn .Module ):
54+ def __init__ (self , inplace : bool = True ):
55+ super ().__init__ ()
56+ self .conv1 = torch .nn .Conv2d (
57+ in_channels = 64 , out_channels = 64 , kernel_size = 7 , padding = "same"
58+ )
59+ self .relu = torch .nn .ReLU (inplace = inplace )
60+
61+ def forward (self , x : torch .Tensor ):
62+ y = self .conv1 (x )
63+ z = self .relu (y )
64+ out = x + z
65+ return out
66+
67+
4668@common .parametrize ("test_data" , test_data_suite )
4769def test_relu_tosa_FP (test_data : torch .Tensor ):
4870 pipeline = TosaPipelineFP [input_t1 ](
@@ -54,6 +76,35 @@ def test_relu_tosa_FP(test_data: torch.Tensor):
5476 pipeline .run ()
5577
5678
79+ # Test the folding of Conv2D with ReLU
80+ @common .parametrize ("test_data" , test_data_conv_relu )
81+ def test_conv_relu_folding_tosa_INT (test_data : torch .Tensor ):
82+ input_data , inplace = test_data ()
83+ pipeline = TosaPipelineINT [input_t1 ](
84+ Conv2d_Relu_Add (inplace = inplace ),
85+ (input_data ,),
86+ [],
87+ [],
88+ )
89+ # We should have :
90+ # 3 quantize_per_tensor nodes: input activation , output of the conv-relu sequence, out of the add
91+ # 4 dequantize_per_tensor nodes: into the conv2d input, into the add, output of the conv-relu sequence, before returning
92+ # 2 dequantize_per_channel nodes: one for the weights and another one for the bias
93+ # In case of incorrect annotation of the ReLU, we get separate Q/DR around both the conv2d and the ReLU and
94+ # therefore more quantize_per_tensor and dequantize_per_tensor nodes
95+ pipeline .add_stage_after (
96+ "quantize" ,
97+ pipeline .tester .check_count ,
98+ {
99+ "quantized_decomposed.quantize_per_tensor.default" : 3 ,
100+ "torch.ops.quantized_decomposed.dequantize_per_tensor.default" : 4 ,
101+ "quantized_decomposed.dequantize_per_channel.default" : 2 ,
102+ },
103+ suffix = "quant_nodes" ,
104+ )
105+ pipeline .run ()
106+
107+
57108@common .parametrize ("test_data" , test_data_suite )
58109def test_relu_tosa_INT (test_data : torch .Tensor ):
59110 pipeline = TosaPipelineINT [input_t1 ](
0 commit comments