@@ -141,59 +141,139 @@ def test_quantized_add(
141141 @expand (
142142 [
143143 # Test case 1: 1x2 input, 1x2 weight (1 output feature)
144- (
145- torch .Size ([1 , 2 ]), # src_shape: 1 sample, 2 input features
146- torch .Size ([1 , 2 ]), # weight_shape: 1 output feature, 2 input features
147- 0 , # in_zero_point
148- torch .tensor ([0 , 0 ], dtype = torch .int8 ), # weight_zero_point
149- torch .tensor (
150- [1073741824 ], dtype = torch .int32
151- ), # out_multiplier (0.5 * 2^31)
152- torch .tensor ([0 ], dtype = torch .int8 ), # out_shift
153- 0 , # out_zero_point
154- torch .tensor ([[- 2 ]], dtype = torch .int8 ), # expected_output
155- ),
144+ * [
145+ (
146+ torch .Size ([1 , 2 ]), # src_shape: 1 sample, 2 input features
147+ torch .Size (
148+ [1 , 2 ]
149+ ), # weight_shape: 1 output feature, 2 input features
150+ 0 , # in_zero_point
151+ torch .tensor ([0 , 0 ], dtype = dtype ), # weight_zero_point
152+ torch .tensor (
153+ [1073741824 ], dtype = torch .int32
154+ ), # out_multiplier (0.5 * 2^31)
155+ torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
156+ 0 , # out_zero_point
157+ torch .tensor ([[- 2 ]], dtype = dtype ), # expected_output
158+ per_tensor ,
159+ )
160+ for (per_tensor , dtype ) in (
161+ (False , torch .int8 ),
162+ (True , torch .int8 ),
163+ (True , torch .uint8 ),
164+ )
165+ ],
156166 # Test case 2: 1x3 input, 2x3 weight (2 output features)
157- (
158- torch .Size ([1 , 3 ]), # src_shape: 1 sample, 3 input features
159- torch .Size ([2 , 3 ]), # weight_shape: 2 output features, 3 input features
160- 0 , # in_zero_point
161- torch .tensor ([0 , 0 , 0 ], dtype = torch .int8 ), # weight_zero_point
162- torch .tensor (
163- [1073741824 ], dtype = torch .int32
164- ), # out_multiplier (0.5 * 2^31)
165- torch .tensor ([0 ], dtype = torch .int8 ), # out_shift
166- 0 , # out_zero_point
167- torch .tensor ([[- 10 , - 30 ]], dtype = torch .int8 ), # expected_output
168- ),
167+ * [
168+ (
169+ torch .Size ([1 , 3 ]), # src_shape: 1 sample, 3 input features
170+ torch .Size (
171+ [2 , 3 ]
172+ ), # weight_shape: 2 output features, 3 input features
173+ 0 , # in_zero_point
174+ torch .tensor ([0 , 0 , 0 ], dtype = dtype ), # weight_zero_point
175+ torch .tensor (
176+ [1073741824 ], dtype = torch .int32
177+ ), # out_multiplier (0.5 * 2^31)
178+ torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
179+ 0 , # out_zero_point
180+ torch .tensor ([[- 10 , - 30 ]], dtype = dtype ), # expected_output
181+ per_tensor ,
182+ )
183+ for (per_tensor , dtype ) in (
184+ (False , torch .int8 ),
185+ (True , torch .int8 ),
186+ (True , torch .uint8 ),
187+ )
188+ ],
169189 # Test case 3: Batch case with different dimensions
170- (
171- torch .Size ([1 , 2 , 2 ]), # src_shape: batch=1, seq=2, features=2
172- torch .Size ([3 , 2 ]), # weight_shape: 3 output features, 2 input features
173- 0 , # in_zero_point
174- torch .tensor ([0 , 0 ], dtype = torch .int8 ), # weight_zero_point
175- torch .tensor (
176- [1073741824 ], dtype = torch .int32
177- ), # out_multiplier (0.5 * 2^31)
178- torch .tensor ([0 ], dtype = torch .int8 ), # out_shift
179- 0 , # out_zero_point
180- torch .tensor (
181- [[[- 2 , - 8 , - 14 ], [- 6 , - 28 , - 50 ]]], dtype = torch .int8
182- ), # expected_output
183- ),
190+ * [
191+ (
192+ torch .Size ([1 , 2 , 2 ]), # src_shape: batch=1, seq=2, features=2
193+ torch .Size (
194+ [3 , 2 ]
195+ ), # weight_shape: 3 output features, 2 input features
196+ 0 , # in_zero_point
197+ torch .tensor ([0 , 0 ], dtype = dtype ), # weight_zero_point
198+ torch .tensor (
199+ [1073741824 ], dtype = torch .int32
200+ ), # out_multiplier (0.5 * 2^31)
201+ torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
202+ 0 , # out_zero_point
203+ torch .tensor (
204+ [[[- 2 , - 8 , - 14 ], [- 6 , - 28 , - 50 ]]], dtype = dtype
205+ ), # expected_output
206+ per_tensor ,
207+ )
208+ for (per_tensor , dtype ) in (
209+ (False , torch .int8 ),
210+ (True , torch .int8 ),
211+ (True , torch .uint8 ),
212+ )
213+ ],
184214 # Test case 4: Non-zero zero points
185- (
186- torch .Size ([1 , 2 ]), # src_shape: 1 sample, 2 input features
187- torch .Size ([2 , 2 ]), # weight_shape: 2 output feature, 1 input feature
188- 2 , # in_zero_point
189- torch .tensor ([1 , 1 ], dtype = torch .int8 ), # weight_zero_point
190- torch .tensor (
191- [268435456 ], dtype = torch .int32
192- ), # out_multiplier (1.0 * 2^31)
193- torch .tensor ([0 ]), # out_shift
194- 1 , # out_zero_point
195- torch .tensor ([[- 15 , 25 ]], dtype = torch .int8 ), # expected_output
196- ),
215+ * [
216+ (
217+ torch .Size ([1 , 2 ]), # src_shape: 1 sample, 2 input features
218+ torch .Size (
219+ [2 , 2 ]
220+ ), # weight_shape: 2 output feature, 1 input feature
221+ 2 , # in_zero_point
222+ torch .tensor ([1 , 1 ], dtype = dtype ), # weight_zero_point
223+ torch .tensor (
224+ [268435456 ], dtype = torch .int32
225+ ), # out_multiplier (1.0 * 2^31)
226+ torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
227+ 1 , # out_zero_point
228+ torch .tensor ([[- 15 , 25 ]], dtype = dtype ), # expected_output
229+ per_tensor ,
230+ )
231+ for (per_tensor , dtype ) in (
232+ (False , torch .int8 ),
233+ (True , torch .int8 ),
234+ (True , torch .uint8 ),
235+ )
236+ ],
237+ # Test case 5: Non-uniform weight zero points
238+ * [
239+ (
240+ torch .Size ([1 , 2 ]), # src_shape: 1 sample, 2 input features
241+ torch .Size (
242+ [2 , 2 ]
243+ ), # weight_shape: 2 output feature, 1 input feature
244+ 2 , # in_zero_point
245+ torch .tensor ([1 , 2 ], dtype = dtype ), # weight_zero_point
246+ torch .tensor (
247+ [268435456 ], dtype = torch .int32
248+ ), # out_multiplier (1.0 * 2^31)
249+ torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
250+ 1 , # out_zero_point
251+ torch .tensor ([[- 23 , 17 ]], dtype = dtype ), # expected_output
252+ False ,
253+ )
254+ for dtype in (torch .int8 , torch .uint8 )
255+ ],
256+ # Test case 6: Non-zero out_shift (shift=1)
257+ * [
258+ (
259+ torch .Size ([1 , 2 ]), # src_shape: 1 sample, 2 input features
260+ torch .Size (
261+ [2 , 2 ]
262+ ), # weight_shape: 2 output features, 2 input features
263+ 2 , # in_zero_point
264+ torch .tensor ([1 , 1 ], dtype = dtype ), # weight_zero_point
265+ torch .tensor (
266+ [268435456 ], dtype = torch .int32
267+ ), # out_multiplier (0.125 * 2^31)
268+ torch .tensor (
269+ [1 ], dtype = torch .int64
270+ ), # out_shift (shift=1, doubles the scale)
271+ 1 , # out_zero_point
272+ torch .tensor ([[- 7 , 13 ]], dtype = dtype ), # expected_output
273+ per_tensor ,
274+ )
275+ for (per_tensor , dtype ) in ((False , torch .int8 ), (True , torch .int8 ))
276+ ],
197277 ]
198278 )
199279 def test_quantized_linear (
@@ -206,6 +286,7 @@ def test_quantized_linear(
206286 out_shift : torch .Tensor ,
207287 out_zero_point : int ,
208288 expected_output : torch .Tensor ,
289+ per_tensor : bool ,
209290 ) -> None :
210291 src = (
211292 torch .arange (np .prod (src_shape ))
@@ -217,8 +298,28 @@ def test_quantized_linear(
217298 .reshape (weight_shape )
218299 .to (expected_output .dtype )
219300 )
220- bias = torch .arange (weight_shape [0 ]).to (expected_output .dtype )
221- output = torch .ops .cadence .quantized_linear (
301+ bias = torch .arange (weight_shape [0 ]).to (torch .int32 )
302+ if per_tensor :
303+ weight_zero_point = weight_zero_point [0 ]
304+ out_multiplier = out_multiplier [0 ]
305+ out_shift = out_shift [0 ]
306+
307+ if per_tensor :
308+ match expected_output .dtype :
309+ case torch .int8 :
310+ linear_op = (
311+ torch .ops .cadence .quantized_linear_asym8sxasym8s_asym8s .per_tensor
312+ )
313+ case torch .uint8 :
314+ linear_op = (
315+ torch .ops .cadence .quantized_linear_asym8uxasym8u_asym8u .per_tensor
316+ )
317+ case _:
318+ linear_op = torch .ops .cadence .quantized_linear .per_tensor
319+ else :
320+ linear_op = torch .ops .cadence .quantized_linear
321+
322+ output = linear_op (
222323 src ,
223324 weight ,
224325 bias ,
0 commit comments