
Commit e51326b

Andrew Grebenisan authored and facebook-github-bot committed
Fixed assumption on out_shift for quantized linear (pytorch#14789)
Summary: out_shift should be int32.
Reviewed By: hsharma35
Differential Revision: D83875670
1 parent a39866c commit e51326b
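For context, a minimal sketch of the fixed-point convention that out_multiplier and out_shift appear to encode, inferred from the test comments in this diff ("0.5 * 2^31", "shift=1, doubles the scale"); the actual Cadence kernel math may differ, and the values here are purely illustrative:

import torch

# Sketch only: out_multiplier as a Q31 fixed-point fraction, out_shift as a
# power-of-two adjustment. After this change both are int32 tensors.
out_multiplier = torch.tensor([1073741824], dtype=torch.int32)  # 0.5 * 2^31
out_shift = torch.tensor([1], dtype=torch.int32)  # shift=1, doubles the scale

effective_scale = (out_multiplier.item() / 2**31) * (2 ** out_shift.item())
print(effective_scale)  # 0.5 * 2 = 1.0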

2 files changed, +10 -10 lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 2 additions & 2 deletions
@@ -330,8 +330,8 @@ def variant(
 if out_shift.numel() != 1:
     raise ValueError("out_shift must be a scalar")

-if out_shift.dtype != torch.int64:
-    raise ValueError("out_shift must be an int64")
+if out_shift.dtype != torch.int32:
+    raise ValueError("out_shift must be an int32")

 _out_shift = int(out_shift.item())
 _out_multiplier = int(out_multiplier[0].item())
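The change above only tightens the dtype check; a standalone sketch of the resulting behavior (check_out_shift is a hypothetical helper used here for illustration, not a function in this file):

import torch

def check_out_shift(out_shift: torch.Tensor) -> int:
    # Mirrors the validation in the hunk above; not the full operator.
    if out_shift.numel() != 1:
        raise ValueError("out_shift must be a scalar")
    if out_shift.dtype != torch.int32:
        raise ValueError("out_shift must be an int32")
    return int(out_shift.item())

check_out_shift(torch.tensor([0], dtype=torch.int32))   # accepted after this change
# check_out_shift(torch.tensor([0], dtype=torch.int64)) # now raises ValueError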

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 8 additions & 8 deletions
@@ -172,7 +172,7 @@ def test_quantized_add(
 torch.tensor(
     [1073741824], dtype=torch.int32
 ), # out_multiplier (0.5 * 2^31)
-torch.tensor([0], dtype=torch.int64), # out_shift
+torch.tensor([0], dtype=torch.int32), # out_shift
 0, # out_zero_point
 torch.tensor([[0]], dtype=dtype), # expected_output
 per_tensor,
@@ -197,7 +197,7 @@ def test_quantized_add(
 torch.tensor(
     [1073741824], dtype=torch.int32
 ), # out_multiplier (0.5 * 2^31)
-torch.tensor([0], dtype=torch.int64), # out_shift
+torch.tensor([0], dtype=torch.int32), # out_shift
 0, # out_zero_point
 torch.tensor([[-2, -8]], dtype=dtype), # expected_output
 per_tensor,
@@ -220,7 +220,7 @@ def test_quantized_add(
 torch.tensor(
     [1073741824], dtype=torch.int32
 ), # out_multiplier (0.5 * 2^31)
-torch.tensor([0], dtype=torch.int64), # out_shift
+torch.tensor([0], dtype=torch.int32), # out_shift
 0, # out_zero_point
 torch.tensor([[0, 0]], dtype=dtype), # expected_output
 per_tensor,
@@ -244,7 +244,7 @@ def test_quantized_add(
 torch.tensor(
     [1073741824], dtype=torch.int32
 ), # out_multiplier (0.5 * 2^31)
-torch.tensor([0], dtype=torch.int64), # out_shift
+torch.tensor([0], dtype=torch.int32), # out_shift
 0, # out_zero_point
 torch.tensor(
     [[[0, -2, -4], [-2, -7, -12]]], dtype=dtype
@@ -270,7 +270,7 @@ def test_quantized_add(
 torch.tensor(
     [268435456], dtype=torch.int32
 ), # out_multiplier (1.0 * 2^31)
-torch.tensor([0], dtype=torch.int64), # out_shift
+torch.tensor([0], dtype=torch.int32), # out_shift
 1, # out_zero_point
 torch.tensor([[1, 1]], dtype=dtype), # expected_output
 per_tensor,
@@ -295,7 +295,7 @@ def test_quantized_add(
 torch.tensor(
     [268435456], dtype=torch.int32
 ), # out_multiplier (1.0 * 2^31)
-torch.tensor([0], dtype=torch.int64), # out_shift
+torch.tensor([0], dtype=torch.int32), # out_shift
 1, # out_zero_point
 torch.tensor([[1, 1]], dtype=dtype), # expected_output
 False,
@@ -317,7 +317,7 @@ def test_quantized_add(
     [268435456], dtype=torch.int32
 ), # out_multiplier (0.125 * 2^31)
 torch.tensor(
-    [1], dtype=torch.int64
+    [1], dtype=torch.int32
 ), # out_shift (shift=1, doubles the scale)
 1, # out_zero_point
 torch.tensor([[1, 2]], dtype=dtype), # expected_output
@@ -339,7 +339,7 @@ def test_quantized_add(
     [268435456], dtype=torch.int32
 ), # out_multiplier (0.125 * 2^31)
 torch.tensor(
-    [1], dtype=torch.int64
+    [1], dtype=torch.int32
 ), # out_shift (shift=1, doubles the scale)
 1, # out_zero_point
 torch.tensor([[1, 2]], dtype=dtype), # expected_output
