
Commit d28f924

Add quantized fully connected ops
Differential Revision: D81942767
Pull Request resolved: #14079
1 parent: e4b1d51

File tree: 4 files changed, 67 insertions(+), 29 deletions(-)

  backends/cadence/aot/ops_registrations.py
  backends/cadence/aot/ref_implementations.py
  backends/cadence/aot/tests/test_ref_implementations.py
  backends/cadence/aot/tests/test_type_dispatch_passes.py


backends/cadence/aot/ops_registrations.py

Lines changed: 4 additions & 0 deletions
@@ -1771,6 +1771,7 @@ def quantized_fully_connected_meta(
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
+    assert src.shape[0] == 1
     out_size = list(src.size())
     weight_size = list(weight.size())
     assert len(weight_size) == 2
@@ -1793,6 +1794,7 @@ def quantized_fully_connected_per_tensor_meta(
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
+    assert src.shape[0] == 1
     out_size = list(src.size())
     weight_size = list(weight.size())
     assert len(weight_size) == 2
@@ -1815,6 +1817,7 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_meta(
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
+    assert src.shape[0] == 1
     out_size = list(src.size())
     weight_size = list(weight.size())
     assert len(weight_size) == 2
@@ -1837,6 +1840,7 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta(
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
+    assert src.shape[0] == 1
     out_size = list(src.size())
     weight_size = list(weight.size())
     assert len(weight_size) == 2
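
The added assertion pins all four fully connected meta kernels to a batch of one before the output shape is computed. A minimal sketch of the shared shape logic, for illustration only: _fully_connected_out_shape is a hypothetical standalone helper, not a function in this file, and the final out_size[-1] assignment is inferred from the shape comments rather than shown in these hunks.

import torch

def _fully_connected_out_shape(src: torch.Tensor, weight: torch.Tensor) -> list[int]:
    # src comes in shape [leading_dims, in_dim]; fully connected requires batch 1
    assert src.shape[0] == 1
    out_size = list(src.size())
    # weight comes in shape [out_dim, in_dim]
    weight_size = list(weight.size())
    assert len(weight_size) == 2
    # output has shape [leading_dims, out_dim] (assumed from the comments above)
    out_size[-1] = weight_size[0]
    return out_size

# e.g. src [1, 16] x weight [32, 16] -> output [1, 32]
print(_fully_connected_out_shape(torch.empty(1, 16), torch.empty(32, 16)))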

backends/cadence/aot/ref_implementations.py

Lines changed: 29 additions & 4 deletions
@@ -249,6 +249,7 @@ def quantized_linear_common(

 def quantized_linear_variant(
     per_tensor: bool,
+    fully_connected: bool,
     src_dtype: torch.dtype | None = None,
     weight_dtype: torch.dtype | None = None,
 ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
@@ -265,6 +266,10 @@ def variant(
         out_zero_point: int,
         offset: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        if fully_connected and src.shape[0] != 1:
+            raise ValueError(
+                "Fully connected quantized linear only supports batch size of 1"
+            )
         if src_dtype and src.dtype != src_dtype:
             raise ValueError(
                 f"src dtype must be {src_dtype}. Got {src.dtype} instead"
@@ -317,25 +322,45 @@ def variant(


 @impl(m, "quantized_linear")
-@quantized_linear_variant(False)
+@quantized_linear_variant(False, False)
 def quantized_linear() -> torch.Tensor: ...


 @impl(m, "quantized_linear.per_tensor")
-@quantized_linear_variant(True)
+@quantized_linear_variant(True, False)
 def quantized_linear_per_tensor() -> torch.Tensor: ...


 @impl(m, "quantized_linear_asym8sxasym8s_asym8s.per_tensor")
-@quantized_linear_variant(True, torch.int8, torch.int8)
+@quantized_linear_variant(True, False, torch.int8, torch.int8)
 def quantized_linear_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ...


 @impl(m, "quantized_linear_asym8uxasym8u_asym8u.per_tensor")
-@quantized_linear_variant(True, torch.uint8, torch.uint8)
+@quantized_linear_variant(True, False, torch.uint8, torch.uint8)
 def quantized_linear_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...


+@impl(m, "quantized_fully_connected")
+@quantized_linear_variant(False, True)
+def quantized_fully_connected() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_fully_connected.per_tensor")
+@quantized_linear_variant(True, True)
+def quantized_fully_connected_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor")
+@quantized_linear_variant(True, True, torch.int8, torch.int8)
+def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor")
+@quantized_linear_variant(True, True, torch.uint8, torch.uint8)
+def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
 @impl(m, "quantized_layer_norm.per_tensor")
 def quantized_layer_norm_per_tensor(
     input_tensor: torch.Tensor,
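
For callers, the new reference ops mirror the quantized_linear argument list (the order below follows the call in test_ref_implementations.py); only the batch-size-1 constraint differs. A hedged usage sketch, assuming the cadence ops are registered by importing this module under the executorch package layout, and with placeholder quantization parameters (the zero points, multiplier, and shift are illustrative values, not calibration results):

import torch
# Assumption: importing the module registers the cadence reference ops.
from executorch.backends.cadence.aot import ref_implementations  # noqa: F401

src = torch.randint(-128, 127, (1, 3), dtype=torch.int8)  # batch must be 1
weight = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
bias = torch.zeros(4, dtype=torch.int32)

out = torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor(
    src,
    weight,
    bias,
    0,           # in_zero_point (illustrative)
    0,           # weight_zero_point (illustrative)
    1073741824,  # out_multiplier (illustrative; roughly 0.5 in Q1.31)
    0,           # out_shift (illustrative)
    0,           # out_zero_point (illustrative)
    None,        # offset
)
print(out.shape)  # torch.Size([1, 4]): [leading_dims, out_dim]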

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 31 additions & 22 deletions
@@ -307,36 +307,45 @@ def test_quantized_linear(
         if per_tensor:
             match expected_output.dtype:
                 case torch.int8:
-                    linear_op = (
-                        torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor
+                    linear_ops = (
+                        torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
+                        torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
                     )
                 case torch.uint8:
-                    linear_op = (
-                        torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor
+                    linear_ops = (
+                        torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
+                        torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
                     )
                 case _:
-                    linear_op = torch.ops.cadence.quantized_linear.per_tensor
+                    linear_ops = (
+                        torch.ops.cadence.quantized_linear.per_tensor,
+                        torch.ops.cadence.quantized_fully_connected.per_tensor,
+                    )
         else:
-            linear_op = torch.ops.cadence.quantized_linear
+            linear_ops = (
+                torch.ops.cadence.quantized_linear,
+                torch.ops.cadence.quantized_fully_connected,
+            )

-        output = linear_op(
-            src,
-            weight,
-            bias,
-            in_zero_point,
-            weight_zero_point,
-            out_multiplier,
-            out_shift,
-            out_zero_point,
-            typing.cast(torch.Tensor, None),
-        )
+        for linear_op in linear_ops:
+            output = linear_op(
+                src,
+                weight,
+                bias,
+                in_zero_point,
+                weight_zero_point,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+                typing.cast(torch.Tensor, None),
+            )

-        self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")
+            self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")

-        self.assertTrue(
-            torch.equal(output, expected_output),
-            f"Values don't match: got {output}, expected {expected_output}",
-        )
+            self.assertTrue(
+                torch.equal(output, expected_output),
+                f"Values don't match: got {output}, expected {expected_output}",
+            )

     @expand(
         [
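
The loop exercises both op families on the same batch-1 inputs, so the new guard never fires in this test; a batched src would raise. A sketch of the expected failure, under the same import assumption as above and with placeholder scalar parameters (argument types assume the per-tensor schema takes plain scalars here):

import torch
from executorch.backends.cadence.aot import ref_implementations  # noqa: F401

batched = torch.randint(-128, 127, (2, 3), dtype=torch.int8)  # batch of 2
weight = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
bias = torch.zeros(4, dtype=torch.int32)

try:
    torch.ops.cadence.quantized_fully_connected.per_tensor(
        batched, weight, bias, 0, 0, 1, 0, 0, None
    )
except ValueError as e:
    # "Fully connected quantized linear only supports batch size of 1"
    print(e)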

backends/cadence/aot/tests/test_type_dispatch_passes.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
 class TestTypeDispatchPasses(unittest.TestCase):
     def test_int8_dispatch_quantized_fully_connected(self) -> None:
         """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant"""
-        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
+        x = torch.randint(-128, 127, (1, 3), dtype=torch.int8)
         w = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
         b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
         gm = single_op_builder(
@@ -46,7 +46,7 @@ def test_int8_dispatch_quantized_fully_connected(self) -> None:

     def test_uint8_dispatch_quantized_fully_connected(self) -> None:
         """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant"""
-        x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
+        x = torch.randint(0, 255, (1, 3), dtype=torch.uint8)
         w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
         b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
         gm = single_op_builder(
@@ -124,7 +124,7 @@ def test_uint8_quantized_linear_dispatch(self) -> None:

     def test_mixed_types_error(self) -> None:
         """Test mixed int8/uint8 inputs should raise RuntimeError"""
-        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
+        x = torch.randint(-128, 127, (1, 3), dtype=torch.int8)
         w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
         b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
         gm = single_op_builder(
