@@ -172,7 +172,7 @@ def test_quantized_add(
172
172
torch .tensor (
173
173
[1073741824 ], dtype = torch .int32
174
174
), # out_multiplier (0.5 * 2^31)
175
- torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
175
+ torch .tensor ([0 ], dtype = torch .int32 ), # out_shift
176
176
0 , # out_zero_point
177
177
torch .tensor ([[0 ]], dtype = dtype ), # expected_output
178
178
per_tensor ,
@@ -197,7 +197,7 @@ def test_quantized_add(
197
197
torch .tensor (
198
198
[1073741824 ], dtype = torch .int32
199
199
), # out_multiplier (0.5 * 2^31)
200
- torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
200
+ torch .tensor ([0 ], dtype = torch .int32 ), # out_shift
201
201
0 , # out_zero_point
202
202
torch .tensor ([[- 2 , - 8 ]], dtype = dtype ), # expected_output
203
203
per_tensor ,
@@ -220,7 +220,7 @@ def test_quantized_add(
220
220
torch .tensor (
221
221
[1073741824 ], dtype = torch .int32
222
222
), # out_multiplier (0.5 * 2^31)
223
- torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
223
+ torch .tensor ([0 ], dtype = torch .int32 ), # out_shift
224
224
0 , # out_zero_point
225
225
torch .tensor ([[0 , 0 ]], dtype = dtype ), # expected_output
226
226
per_tensor ,
@@ -244,7 +244,7 @@ def test_quantized_add(
244
244
torch .tensor (
245
245
[1073741824 ], dtype = torch .int32
246
246
), # out_multiplier (0.5 * 2^31)
247
- torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
247
+ torch .tensor ([0 ], dtype = torch .int32 ), # out_shift
248
248
0 , # out_zero_point
249
249
torch .tensor (
250
250
[[[0 , - 2 , - 4 ], [- 2 , - 7 , - 12 ]]], dtype = dtype
@@ -270,7 +270,7 @@ def test_quantized_add(
270
270
torch .tensor (
271
271
[268435456 ], dtype = torch .int32
272
272
), # out_multiplier (1.0 * 2^31)
273
- torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
273
+ torch .tensor ([0 ], dtype = torch .int32 ), # out_shift
274
274
1 , # out_zero_point
275
275
torch .tensor ([[1 , 1 ]], dtype = dtype ), # expected_output
276
276
per_tensor ,
@@ -295,7 +295,7 @@ def test_quantized_add(
295
295
torch .tensor (
296
296
[268435456 ], dtype = torch .int32
297
297
), # out_multiplier (1.0 * 2^31)
298
- torch .tensor ([0 ], dtype = torch .int64 ), # out_shift
298
+ torch .tensor ([0 ], dtype = torch .int32 ), # out_shift
299
299
1 , # out_zero_point
300
300
torch .tensor ([[1 , 1 ]], dtype = dtype ), # expected_output
301
301
False ,
@@ -317,7 +317,7 @@ def test_quantized_add(
317
317
[268435456 ], dtype = torch .int32
318
318
), # out_multiplier (0.125 * 2^31)
319
319
torch .tensor (
320
- [1 ], dtype = torch .int64
320
+ [1 ], dtype = torch .int32
321
321
), # out_shift (shift=1, doubles the scale)
322
322
1 , # out_zero_point
323
323
torch .tensor ([[1 , 2 ]], dtype = dtype ), # expected_output
@@ -339,7 +339,7 @@ def test_quantized_add(
339
339
[268435456 ], dtype = torch .int32
340
340
), # out_multiplier (0.125 * 2^31)
341
341
torch .tensor (
342
- [1 ], dtype = torch .int64
342
+ [1 ], dtype = torch .int32
343
343
), # out_shift (shift=1, doubles the scale)
344
344
1 , # out_zero_point
345
345
torch .tensor ([[1 , 2 ]], dtype = dtype ), # expected_output
0 commit comments