@@ -141,59 +141,139 @@ def test_quantized_add(
141
141
@expand (
142
142
[
143
143
# Test case 1: 1x2 input, 1x2 weight (1 output feature)
144
- (
145
- torch .Size ([1 , 2 ]), # src_shape: 1 sample, 2 input features
146
- torch .Size ([1 , 2 ]), # weight_shape: 1 output feature, 2 input features
147
- 0 , # in_zero_point
148
- torch .tensor ([0 , 0 ], dtype = torch .int8 ), # weight_zero_point
149
- torch .tensor (
150
- [1073741824 ], dtype = torch .int32
151
- ), # out_multiplier (0.5 * 2^31)
152
- torch .tensor ([0 ], dtype = torch .int8 ), # out_shift
153
- 0 , # out_zero_point
154
- torch .tensor ([[- 2 ]], dtype = torch .int8 ), # expected_output
155
- ),
144
+ * [
145
+ (
146
+ torch .Size ([1 , 2 ]), # src_shape: 1 sample, 2 input features
147
+ torch .Size (
148
+ [1 , 2 ]
149
+ ), # weight_shape: 1 output feature, 2 input features
150
+ 0 , # in_zero_point
151
+ torch .tensor ([0 , 0 ], dtype = dtype ), # weight_zero_point
152
+ torch .tensor (
153
+ [1073741824 ], dtype = torch .int32
154
+ ), # out_multiplier (0.5 * 2^31)
155
+ torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
156
+ 0 , # out_zero_point
157
+ torch .tensor ([[- 2 ]], dtype = dtype ), # expected_output
158
+ per_tensor ,
159
+ )
160
+ for (per_tensor , dtype ) in (
161
+ (False , torch .int8 ),
162
+ (True , torch .int8 ),
163
+ (True , torch .uint8 ),
164
+ )
165
+ ],
156
166
# Test case 2: 1x3 input, 2x3 weight (2 output features)
157
- (
158
- torch .Size ([1 , 3 ]), # src_shape: 1 sample, 3 input features
159
- torch .Size ([2 , 3 ]), # weight_shape: 2 output features, 3 input features
160
- 0 , # in_zero_point
161
- torch .tensor ([0 , 0 , 0 ], dtype = torch .int8 ), # weight_zero_point
162
- torch .tensor (
163
- [1073741824 ], dtype = torch .int32
164
- ), # out_multiplier (0.5 * 2^31)
165
- torch .tensor ([0 ], dtype = torch .int8 ), # out_shift
166
- 0 , # out_zero_point
167
- torch .tensor ([[- 10 , - 30 ]], dtype = torch .int8 ), # expected_output
168
- ),
167
+ * [
168
+ (
169
+ torch .Size ([1 , 3 ]), # src_shape: 1 sample, 3 input features
170
+ torch .Size (
171
+ [2 , 3 ]
172
+ ), # weight_shape: 2 output features, 3 input features
173
+ 0 , # in_zero_point
174
+ torch .tensor ([0 , 0 , 0 ], dtype = dtype ), # weight_zero_point
175
+ torch .tensor (
176
+ [1073741824 ], dtype = torch .int32
177
+ ), # out_multiplier (0.5 * 2^31)
178
+ torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
179
+ 0 , # out_zero_point
180
+ torch .tensor ([[- 10 , - 30 ]], dtype = dtype ), # expected_output
181
+ per_tensor ,
182
+ )
183
+ for (per_tensor , dtype ) in (
184
+ (False , torch .int8 ),
185
+ (True , torch .int8 ),
186
+ (True , torch .uint8 ),
187
+ )
188
+ ],
169
189
# Test case 3: Batch case with different dimensions
170
- (
171
- torch .Size ([1 , 2 , 2 ]), # src_shape: batch=1, seq=2, features=2
172
- torch .Size ([3 , 2 ]), # weight_shape: 3 output features, 2 input features
173
- 0 , # in_zero_point
174
- torch .tensor ([0 , 0 ], dtype = torch .int8 ), # weight_zero_point
175
- torch .tensor (
176
- [1073741824 ], dtype = torch .int32
177
- ), # out_multiplier (0.5 * 2^31)
178
- torch .tensor ([0 ], dtype = torch .int8 ), # out_shift
179
- 0 , # out_zero_point
180
- torch .tensor (
181
- [[[- 2 , - 8 , - 14 ], [- 6 , - 28 , - 50 ]]], dtype = torch .int8
182
- ), # expected_output
183
- ),
190
+ * [
191
+ (
192
+ torch .Size ([1 , 2 , 2 ]), # src_shape: batch=1, seq=2, features=2
193
+ torch .Size (
194
+ [3 , 2 ]
195
+ ), # weight_shape: 3 output features, 2 input features
196
+ 0 , # in_zero_point
197
+ torch .tensor ([0 , 0 ], dtype = dtype ), # weight_zero_point
198
+ torch .tensor (
199
+ [1073741824 ], dtype = torch .int32
200
+ ), # out_multiplier (0.5 * 2^31)
201
+ torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
202
+ 0 , # out_zero_point
203
+ torch .tensor (
204
+ [[[- 2 , - 8 , - 14 ], [- 6 , - 28 , - 50 ]]], dtype = dtype
205
+ ), # expected_output
206
+ per_tensor ,
207
+ )
208
+ for (per_tensor , dtype ) in (
209
+ (False , torch .int8 ),
210
+ (True , torch .int8 ),
211
+ (True , torch .uint8 ),
212
+ )
213
+ ],
184
214
# Test case 4: Non-zero zero points
185
- (
186
- torch .Size ([1 , 2 ]), # src_shape: 1 sample, 2 input features
187
- torch .Size ([2 , 2 ]), # weight_shape: 2 output feature, 1 input feature
188
- 2 , # in_zero_point
189
- torch .tensor ([1 , 1 ], dtype = torch .int8 ), # weight_zero_point
190
- torch .tensor (
191
- [268435456 ], dtype = torch .int32
192
- ), # out_multiplier (1.0 * 2^31)
193
- torch .tensor ([0 ]), # out_shift
194
- 1 , # out_zero_point
195
- torch .tensor ([[- 15 , 25 ]], dtype = torch .int8 ), # expected_output
196
- ),
215
+ * [
216
+ (
217
+ torch .Size ([1 , 2 ]), # src_shape: 1 sample, 2 input features
218
+ torch .Size (
219
+ [2 , 2 ]
220
+ ), # weight_shape: 2 output features, 2 input features
221
+ 2 , # in_zero_point
222
+ torch .tensor ([1 , 1 ], dtype = dtype ), # weight_zero_point
223
+ torch .tensor (
224
+ [268435456 ], dtype = torch .int32
225
+ ), # out_multiplier (0.125 * 2^31)
226
+ torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
227
+ 1 , # out_zero_point
228
+ torch .tensor ([[- 15 , 25 ]], dtype = dtype ), # expected_output
229
+ per_tensor ,
230
+ )
231
+ for (per_tensor , dtype ) in (
232
+ (False , torch .int8 ),
233
+ (True , torch .int8 ),
234
+ (True , torch .uint8 ),
235
+ )
236
+ ],
237
+ # Test case 5: Non-uniform weight zero points
238
+ * [
239
+ (
240
+ torch .Size ([1 , 2 ]), # src_shape: 1 sample, 2 input features
241
+ torch .Size (
242
+ [2 , 2 ]
243
+ ), # weight_shape: 2 output features, 2 input features
244
+ 2 , # in_zero_point
245
+ torch .tensor ([1 , 2 ], dtype = dtype ), # weight_zero_point
246
+ torch .tensor (
247
+ [268435456 ], dtype = torch .int32
248
+ ), # out_multiplier (0.125 * 2^31)
249
+ torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
250
+ 1 , # out_zero_point
251
+ torch .tensor ([[- 23 , 17 ]], dtype = dtype ), # expected_output
252
+ False ,
253
+ )
254
+ for dtype in (torch .int8 , torch .uint8 )
255
+ ],
256
+ # Test case 6: Non-zero out_shift (shift=1)
257
+ * [
258
+ (
259
+ torch .Size ([1 , 2 ]), # src_shape: 1 sample, 2 input features
260
+ torch .Size (
261
+ [2 , 2 ]
262
+ ), # weight_shape: 2 output features, 2 input features
263
+ 2 , # in_zero_point
264
+ torch .tensor ([1 , 1 ], dtype = dtype ), # weight_zero_point
265
+ torch .tensor (
266
+ [268435456 ], dtype = torch .int32
267
+ ), # out_multiplier (0.125 * 2^31)
268
+ torch .tensor (
269
+ [1 ], dtype = torch .int64
270
+ ), # out_shift (shift=1, doubles the scale)
271
+ 1 , # out_zero_point
272
+ torch .tensor ([[- 7 , 13 ]], dtype = dtype ), # expected_output
273
+ per_tensor ,
274
+ )
275
+ for (per_tensor , dtype ) in ((False , torch .int8 ), (True , torch .int8 ))
276
+ ],
197
277
]
198
278
)
199
279
def test_quantized_linear (
@@ -206,6 +286,7 @@ def test_quantized_linear(
206
286
out_shift : torch .Tensor ,
207
287
out_zero_point : int ,
208
288
expected_output : torch .Tensor ,
289
+ per_tensor : bool ,
209
290
) -> None :
210
291
src = (
211
292
torch .arange (np .prod (src_shape ))
@@ -217,8 +298,28 @@ def test_quantized_linear(
217
298
.reshape (weight_shape )
218
299
.to (expected_output .dtype )
219
300
)
220
- bias = torch .arange (weight_shape [0 ]).to (expected_output .dtype )
221
- output = torch .ops .cadence .quantized_linear (
301
+ bias = torch .arange (weight_shape [0 ]).to (torch .int32 )
302
+ if per_tensor :
303
+ weight_zero_point = weight_zero_point [0 ]
304
+ out_multiplier = out_multiplier [0 ]
305
+ out_shift = out_shift [0 ]
306
+
307
+ if per_tensor :
308
+ match expected_output .dtype :
309
+ case torch .int8 :
310
+ linear_op = (
311
+ torch .ops .cadence .quantized_linear_asym8sxasym8s_asym8s .per_tensor
312
+ )
313
+ case torch .uint8 :
314
+ linear_op = (
315
+ torch .ops .cadence .quantized_linear_asym8uxasym8u_asym8u .per_tensor
316
+ )
317
+ case _:
318
+ linear_op = torch .ops .cadence .quantized_linear .per_tensor
319
+ else :
320
+ linear_op = torch .ops .cadence .quantized_linear
321
+
322
+ output = linear_op (
222
323
src ,
223
324
weight ,
224
325
bias ,
0 commit comments