
Commit a22f0a8

Andrew Grebenisan authored and facebook-github-bot committed
Add quantized fully connected ops
Summary: The quantized fully connected ops are just aliases for quantized_linear, so this change registers aliases for all of its variants.

Differential Revision: D81942767
1 parent: 0a00f15

File tree

2 files changed: +51 −22 lines

backends/cadence/aot/ref_implementations.py

Lines changed: 20 additions & 0 deletions
@@ -336,6 +336,26 @@ def quantized_linear_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ...
 def quantized_linear_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
+@impl(m, "quantized_fully_connected")
+@quantized_linear_variant(False)
+def quantized_fully_connected() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_fully_connected.per_tensor")
+@quantized_linear_variant(True)
+def quantized_fully_connected_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor")
+@quantized_linear_variant(True, torch.int8, torch.int8)
+def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor")
+@quantized_linear_variant(True, torch.uint8, torch.uint8)
+def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
 @impl(m, "quantized_layer_norm.per_tensor")
 def quantized_layer_norm_per_tensor(
     input_tensor: torch.Tensor,
backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 31 additions & 22 deletions
@@ -307,36 +307,45 @@ def test_quantized_linear(
         if per_tensor:
             match expected_output.dtype:
                 case torch.int8:
-                    linear_op = (
-                        torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor
+                    linear_ops = (
+                        torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
+                        torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
                     )
                 case torch.uint8:
-                    linear_op = (
-                        torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor
+                    linear_ops = (
+                        torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
+                        torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
                     )
                 case _:
-                    linear_op = torch.ops.cadence.quantized_linear.per_tensor
+                    linear_ops = (
+                        torch.ops.cadence.quantized_linear.per_tensor,
+                        torch.ops.cadence.quantized_fully_connected.per_tensor,
+                    )
         else:
-            linear_op = torch.ops.cadence.quantized_linear
+            linear_ops = (
+                torch.ops.cadence.quantized_linear,
+                torch.ops.cadence.quantized_fully_connected,
+            )
 
-        output = linear_op(
-            src,
-            weight,
-            bias,
-            in_zero_point,
-            weight_zero_point,
-            out_multiplier,
-            out_shift,
-            out_zero_point,
-            typing.cast(torch.Tensor, None),
-        )
+        for linear_op in linear_ops:
+            output = linear_op(
+                src,
+                weight,
+                bias,
+                in_zero_point,
+                weight_zero_point,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+                typing.cast(torch.Tensor, None),
+            )
 
-        self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")
+            self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")
 
-        self.assertTrue(
-            torch.equal(output, expected_output),
-            f"Values don't match: got {output}, expected {expected_output}",
-        )
+            self.assertTrue(
+                torch.equal(output, expected_output),
+                f"Values don't match: got {output}, expected {expected_output}",
+            )
 
     @expand(
         [