@@ -36,12 +36,11 @@ def test_quantize_per_tensor(
3636 ) -> None :
3737 input_tensor = torch .tensor ([input_value ])
3838 scale = (f_max - f_min ) / (q_max - q_min )
39- inv_scale = 1.0 / scale
40- zero_point = round (- f_min * inv_scale ) + q_min
39+ zero_point = round (- f_min * 1 / scale ) + q_min
4140 expected_output = torch .tensor ([expected_value ], dtype = target_dtype )
4241
4342 output = torch .ops .cadence .quantize_per_tensor (
44- input_tensor , inv_scale , zero_point , q_min , q_max , target_dtype
43+ input_tensor , scale , zero_point , q_min , q_max , target_dtype
4544 )
4645
4746 self .assertEqual (
@@ -85,7 +84,7 @@ def test_dequantize_per_tensor(
8584 expected_output = torch .tensor ([expected_value ], dtype = torch .float32 )
8685
8786 output = torch .ops .cadence .dequantize_per_tensor (
88- input_tensor , scale , zero_point , q_min , q_max , torch . float32
87+ input_tensor , scale , zero_point , q_min , q_max , input_tensor . dtype
8988 )
9089
9190 self .assertEqual (
@@ -175,7 +174,7 @@ def test_quantized_add(
175174 ), # out_multiplier (0.5 * 2^31)
176175 torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
177176 0 , # out_zero_point
178- torch .tensor ([[- 2 ]], dtype = dtype ), # expected_output
177+ torch .tensor ([[0 ]], dtype = dtype ), # expected_output
179178 per_tensor ,
180179 False ,
181180 False ,
@@ -200,14 +199,36 @@ def test_quantized_add(
200199 ), # out_multiplier (0.5 * 2^31)
201200 torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
202201 0 , # out_zero_point
203- torch .tensor ([[- 10 , - 30 ]], dtype = dtype ), # expected_output
202+ torch .tensor ([[- 2 , - 8 ]], dtype = dtype ), # expected_output
204203 per_tensor ,
205204 False ,
206205 False ,
207206 )
208207 for (per_tensor , dtype ) in (
209208 (False , torch .int8 ),
210209 (True , torch .int8 ),
210+ )
211+ ],
212+ * [
213+ (
214+ torch .Size ([1 , 3 ]), # src_shape: 1 sample, 3 input features
215+ torch .Size (
216+ [2 , 3 ]
217+ ), # weight_shape: 2 output features, 3 input features
218+ 0 , # in_zero_point
219+ torch .tensor ([0 , 0 , 0 ], dtype = dtype ), # weight_zero_point
220+ torch .tensor (
221+ [1073741824 ], dtype = torch .int32
222+ ), # out_multiplier (0.5 * 2^31)
223+ torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
224+ 0 , # out_zero_point
225+ torch .tensor ([[0 , 0 ]], dtype = dtype ), # expected_output
226+ per_tensor ,
227+ False ,
228+ False ,
229+ )
230+ for (per_tensor , dtype ) in (
231+ (False , torch .uint8 ),
211232 (True , torch .uint8 ),
212233 )
213234 ],
@@ -226,7 +247,7 @@ def test_quantized_add(
226247 torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
227248 0 , # out_zero_point
228249 torch .tensor (
229- [[[- 2 , - 8 , - 14 ], [- 6 , - 28 , - 50 ]]], dtype = dtype
250+ [[[0 , - 2 , - 4 ], [- 2 , - 7 , - 12 ]]], dtype = dtype
230251 ), # expected_output
231252 per_tensor ,
232253 False ,
@@ -235,7 +256,6 @@ def test_quantized_add(
235256 for (per_tensor , dtype ) in (
236257 (False , torch .int8 ),
237258 (True , torch .int8 ),
238- (True , torch .uint8 ),
239259 )
240260 ],
241261 # Test case 4: Non-zero zero points
@@ -252,15 +272,15 @@ def test_quantized_add(
252272 ), # out_multiplier (1.0 * 2^31)
253273 torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
254274 1 , # out_zero_point
255- torch .tensor ([[- 15 , 25 ]], dtype = dtype ), # expected_output
275+ torch .tensor ([[1 , 1 ]], dtype = dtype ), # expected_output
256276 per_tensor ,
257277 False ,
258278 False ,
259279 )
260280 for (per_tensor , dtype ) in (
261281 (False , torch .int8 ),
262282 (True , torch .int8 ),
263- (True , torch .uint8 ),
283+ # (True, torch.uint8),
264284 )
265285 ],
266286 # Test case 5: Non-uniform weight zero points
@@ -277,12 +297,12 @@ def test_quantized_add(
277297 ), # out_multiplier (1.0 * 2^31)
278298 torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
279299 1 , # out_zero_point
280- torch .tensor ([[- 23 , 17 ]], dtype = dtype ), # expected_output
300+ torch .tensor ([[1 , 1 ]], dtype = dtype ), # expected_output
281301 False ,
282302 False ,
283303 False ,
284304 )
285- for dtype in (torch .int8 , torch . uint8 )
305+ for dtype in (torch .int8 ,)
286306 ],
287307 # Test case 6: Non-zero out_shift (shift=1)
288308 * [
@@ -300,7 +320,7 @@ def test_quantized_add(
300320 [1 ], dtype = torch .int64
301321 ), # out_shift (shift=1, doubles the scale)
302322 1 , # out_zero_point
303- torch .tensor ([[- 7 , 13 ]], dtype = dtype ), # expected_output
323+ torch .tensor ([[1 , 2 ]], dtype = dtype ), # expected_output
304324 per_tensor ,
305325 False ,
306326 False ,
@@ -322,13 +342,13 @@ def test_quantized_add(
322342 [1 ], dtype = torch .int64
323343 ), # out_shift (shift=1, doubles the scale)
324344 1 , # out_zero_point
325- torch .tensor ([[- 7 , 17 ]], dtype = dtype ), # expected_output
345+ torch .tensor ([[1 , 2 ]], dtype = dtype ), # expected_output
326346 per_tensor ,
327347 matmul ,
328348 transposed_matmul ,
329349 )
330350 for (matmul , transposed_matmul ) in ((True , False ), (True , True ))
331- for (per_tensor , dtype ) in ((True , torch .int8 ), ( True , torch . uint8 ) )
351+ for (per_tensor , dtype ) in ((True , torch .int8 ),)
332352 ],
333353 ]
334354 )
@@ -1045,7 +1065,20 @@ def test_quantized_conv_per_tensor(
10451065 [4 , 2 , 0 , - 2 ], dtype = dtype
10461066 ), # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
10471067 )
1048- for dtype in [torch .int8 , torch .uint8 ]
1068+ for dtype in [torch .int8 ]
1069+ ],
1070+ * [
1071+ (
1072+ "positive_with_shift_unsigned" ,
1073+ torch .tensor ([2 , 4 , 6 , 8 ], dtype = dtype ), # input
1074+ 1 , # X_zero_point
1075+ 5 , # out_zero_point
1076+ 1073741824 , # out_multiplier (0.5 * 2^31)
1077+ 1 , # out_shift (multiply by 2^1 = 2)
1078+ dtype , # dtype
1079+ torch .tensor ([4 , 2 , 0 , 0 ], dtype = dtype ),
1080+ )
1081+ for dtype in [torch .uint8 ]
10491082 ],
10501083 # Test case 4: Non-per-tensor
10511084 * [
0 commit comments