@@ -1040,6 +1040,116 @@ def test_quantized_conv_per_tensor(
10401040 f"Output values don't match expected. Got { output } , expected { expected_output } " ,
10411041 )
10421042
1043+ @expand (
1044+ [
1045+ # Test case 1: Basic 1D convolution with int8 weights
1046+ (
1047+ "basic_int8_weights" ,
1048+ torch .tensor (
1049+ [[[1.0 , 2.0 , 3.0 , 4.0 , 5.0 ]]], dtype = torch .float32
1050+ ), # src: 1x1x5
1051+ torch .tensor ([[[1 , - 1 , 2 ]]], dtype = torch .int8 ), # weight: 1x1x3
1052+ 0.1 , # w_scale
1053+ torch .tensor ([1 ], dtype = torch .int8 ), # bias: 1
1054+ 0.2 , # b_scale
1055+ torch .tensor (
1056+ [[[0.7 , 0.9 , 1.1 ]]], dtype = torch .float32
1057+ ), # expected: conv1d result
1058+ ),
1059+ # Test case 2: Multiple input channels
1060+ (
1061+ "multi_input_channels" ,
1062+ torch .tensor (
1063+ [[[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ]]], dtype = torch .float32
1064+ ), # src: 1x2x3
1065+ torch .tensor ([[[2 , 1 ], [1 , 2 ]]], dtype = torch .int8 ), # weight: 1x2x2
1066+ 0.5 , # w_scale
1067+ torch .tensor ([1 ], dtype = torch .int8 ), # bias: 1
1068+ 1.0 , # b_scale
1069+ torch .tensor ([[[10.0 , 13.0 ]]], dtype = torch .float32 ), # expected
1070+ ),
1071+ # Test case 3: Multiple output channels
1072+ (
1073+ "multi_output_channels" ,
1074+ torch .tensor (
1075+ [[[1.0 , 2.0 , 3.0 , 4.0 ]]], dtype = torch .float32
1076+ ), # src: 1x1x4
1077+ torch .tensor ([[[1 , - 1 ]], [[2 , 0 ]]], dtype = torch .int8 ), # weight: 2x1x2
1078+ 0.25 , # w_scale
1079+ torch .tensor ([0 , 1 ], dtype = torch .int8 ), # bias: 2
1080+ 0.5 , # b_scale
1081+ torch .tensor (
1082+ [[[- 0.25 , - 0.25 , - 0.25 ], [1.0 , 1.5 , 2.0 ]]], dtype = torch .float32
1083+ ), # expected
1084+ ),
1085+ # Test case 4: Batch size > 1
1086+ (
1087+ "batch_size_2" ,
1088+ torch .tensor (
1089+ [[[1.0 , 2.0 , 3.0 ]], [[4.0 , 5.0 , 6.0 ]]], dtype = torch .float32
1090+ ), # src: 2x1x3
1091+ torch .tensor ([[[1 , 1 ]]], dtype = torch .int8 ), # weight: 1x1x2
1092+ 1.0 , # w_scale
1093+ torch .tensor ([0 ], dtype = torch .int8 ), # bias: 1
1094+ 1.0 , # b_scale
1095+ torch .tensor (
1096+ [[[3.0 , 5.0 ]], [[9.0 , 11.0 ]]], dtype = torch .float32
1097+ ), # expected
1098+ ),
1099+ # Test case 5: Zero weights and bias
1100+ (
1101+ "zero_weights_bias" ,
1102+ torch .tensor ([[[1.0 , 2.0 , 3.0 ]]], dtype = torch .float32 ), # src: 1x1x3
1103+ torch .tensor ([[[0 , 0 ]]], dtype = torch .int8 ), # weight: 1x1x2
1104+ 0.1 , # w_scale
1105+ torch .tensor ([0 ], dtype = torch .int8 ), # bias: 1
1106+ 1.0 , # b_scale
1107+ torch .tensor ([[[0.0 , 0.0 ]]], dtype = torch .float32 ), # expected
1108+ ),
1109+ # Test case 6: Negative weights
1110+ (
1111+ "negative_weights" ,
1112+ torch .tensor ([[[2.0 , 4.0 , 6.0 ]]], dtype = torch .float32 ), # src: 1x1x3
1113+ torch .tensor ([[[- 2 , - 1 ]]], dtype = torch .int8 ), # weight: 1x1x2
1114+ 0.5 , # w_scale
+                torch.tensor([2], dtype=torch.int8),  # bias: 1
+                1.0,  # b_scale
+                torch.tensor([[[-2.0, -5.0]]], dtype=torch.float32),  # expected
+            ),
+        ]
+    )
+    def test_quantized_w8a32_conv(
+        self,
+        name: str,
+        src: torch.Tensor,
+        weight: torch.Tensor,
+        w_scale: float,
+        bias: torch.Tensor,
+        b_scale: float,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.quantized_w8a32_conv(
+            src, weight, w_scale, bias, b_scale
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            torch.float32,
+            f"Output dtype should be float32 in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
     @expand(
         [
             # Test case 1: Basic int8 case with negative scale
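
For reference, the expected outputs in the new test cases can be reproduced by dequantizing the int8 weights and bias with their scales and running a plain float conv1d. The sketch below is an assumption about the op's semantics inferred from the test data, not the Cadence kernel itself, and the helper name reference_w8a32_conv is hypothetical.

import torch
import torch.nn.functional as F

def reference_w8a32_conv(
    src: torch.Tensor,      # float32 activations, shape (N, C_in, L)
    weight: torch.Tensor,   # int8 weights, shape (C_out, C_in, K)
    w_scale: float,
    bias: torch.Tensor,     # int8 bias, shape (C_out,)
    b_scale: float,
) -> torch.Tensor:
    # Dequantize weights and bias, then apply a standard float conv1d
    # (stride 1, no padding), matching the shapes used in the test cases.
    w = weight.to(torch.float32) * w_scale
    b = bias.to(torch.float32) * b_scale
    return F.conv1d(src, w, b)

# Test case 1 check: prints tensor([[[0.7000, 0.9000, 1.1000]]])
print(
    reference_w8a32_conv(
        torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0]]]),
        torch.tensor([[[1, -1, 2]]], dtype=torch.int8),
        0.1,
        torch.tensor([1], dtype=torch.int8),
        0.2,
    )
)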