
Commit 7cb6cd4

Andrew Grebenisan authored and facebook-github-bot committed
Add quantized fully connected ops (pytorch#14079)
Summary: The quantized fully connected ops are just aliases for quantized_linear, so aliases were created for all of them. Differential Revision: D81942767
1 parent: e790a3d
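
Since the fully connected ops share the linear implementation, the change amounts to registering additional op names against the same decorated kernel. Below is a minimal, self-contained sketch of that aliasing pattern; the dict-based registry and `impl` here are hypothetical stand-ins for the real `@impl(m, ...)` machinery in `ref_implementations.py`:

```python
from typing import Callable

import torch

# Hypothetical stand-in for the registry that @impl(m, name) writes to.
_OP_REGISTRY: dict[str, Callable[..., torch.Tensor]] = {}


def impl(name: str) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    """Register the decorated function under `name` and return it unchanged."""

    def register(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        _OP_REGISTRY[name] = fn
        return fn

    return register


def linear_kernel(src: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Shared reference implementation; every alias resolves to this callable.
    return torch.nn.functional.linear(src, weight, bias)


# Aliasing: two registered names, one implementation.
impl("quantized_linear")(linear_kernel)
impl("quantized_fully_connected")(linear_kernel)

assert _OP_REGISTRY["quantized_linear"] is _OP_REGISTRY["quantized_fully_connected"]
```

Because both names resolve to the same callable, a fix to the shared kernel automatically covers every alias.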

File tree: 3 files changed, +64 -26 lines

backends/cadence/aot/ops_registrations.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -1771,6 +1771,7 @@ def quantized_fully_connected_meta(
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
+    assert src.shape[0] == 1
     out_size = list(src.size())
     weight_size = list(weight.size())
     assert len(weight_size) == 2
@@ -1793,6 +1794,7 @@ def quantized_fully_connected_per_tensor_meta(
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
+    assert src.shape[0] == 1
     out_size = list(src.size())
     weight_size = list(weight.size())
     assert len(weight_size) == 2
@@ -1815,6 +1817,7 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_meta(
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
+    assert src.shape[0] == 1
     out_size = list(src.size())
     weight_size = list(weight.size())
     assert len(weight_size) == 2
@@ -1837,6 +1840,7 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta(
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
+    assert src.shape[0] == 1
     out_size = list(src.size())
     weight_size = list(weight.size())
     assert len(weight_size) == 2
```
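
These meta functions do shape inference only: the output keeps the leading dims of `src` and takes `out_dim` from the weight, and the new assert pins the fully connected variants to a batch of 1. A self-contained sketch of that shape rule (the helper name is illustrative, not a Cadence signature):

```python
import torch


def fully_connected_out_shape(src: torch.Tensor, weight: torch.Tensor) -> list[int]:
    # Illustrative helper mirroring the meta functions above.
    # src: [leading_dims, in_dim]; weight: [out_dim, in_dim]
    assert src.shape[0] == 1  # fully connected is linear restricted to batch 1
    weight_size = list(weight.size())
    assert len(weight_size) == 2
    out_size = list(src.size())
    out_size[-1] = weight_size[0]  # output: [leading_dims, out_dim]
    return out_size


print(fully_connected_out_shape(torch.empty(1, 16), torch.empty(32, 16)))  # [1, 32]
```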

backends/cadence/aot/ref_implementations.py

Lines changed: 29 additions & 4 deletions
```diff
@@ -249,6 +249,7 @@ def quantized_linear_common(
 
 def quantized_linear_variant(
     per_tensor: bool,
+    fully_connected: bool,
     src_dtype: torch.dtype | None = None,
     weight_dtype: torch.dtype | None = None,
 ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
@@ -265,6 +266,10 @@ def variant(
         out_zero_point: int,
         offset: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        if fully_connected and src.shape[0] != 1:
+            raise ValueError(
+                "Fully connected quantized linear only supports batch size of 1"
+            )
         if src_dtype and src.dtype != src_dtype:
             raise ValueError(
                 f"src dtype must be {src_dtype}. Got {src.dtype} instead"
@@ -317,25 +322,45 @@ def variant(
 
 
 @impl(m, "quantized_linear")
-@quantized_linear_variant(False)
+@quantized_linear_variant(False, False)
 def quantized_linear() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_linear.per_tensor")
-@quantized_linear_variant(True)
+@quantized_linear_variant(True, False)
 def quantized_linear_per_tensor() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_linear_asym8sxasym8s_asym8s.per_tensor")
-@quantized_linear_variant(True, torch.int8, torch.int8)
+@quantized_linear_variant(True, False, torch.int8, torch.int8)
 def quantized_linear_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_linear_asym8uxasym8u_asym8u.per_tensor")
-@quantized_linear_variant(True, torch.uint8, torch.uint8)
+@quantized_linear_variant(True, False, torch.uint8, torch.uint8)
 def quantized_linear_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
+@impl(m, "quantized_fully_connected")
+@quantized_linear_variant(False, True)
+def quantized_fully_connected() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_fully_connected.per_tensor")
+@quantized_linear_variant(True, True)
+def quantized_fully_connected_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor")
+@quantized_linear_variant(True, True, torch.int8, torch.int8)
+def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor")
+@quantized_linear_variant(True, True, torch.uint8, torch.uint8)
+def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
 @impl(m, "quantized_layer_norm.per_tensor")
 def quantized_layer_norm_per_tensor(
     input_tensor: torch.Tensor,
```
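
For context, the arithmetic these variants delegate to `quantized_linear_common` (not shown in this diff) is an affine-quantized matmul with fixed-point requantization. The sketch below assumes one common multiplier/shift convention; the real kernel's scaling and rounding details may differ:

```python
import torch


def quantized_linear_reference(
    src: torch.Tensor,      # quantized activations, e.g. int8, [batch, in_dim]
    weight: torch.Tensor,   # quantized weights, [out_dim, in_dim]
    bias: torch.Tensor,     # int32 bias in accumulator scale, [out_dim]
    in_zero_point: int,
    weight_zero_point: int,
    out_multiplier: int,    # fixed-point encoding of the combined scale
    out_shift: int,
    out_zero_point: int,
) -> torch.Tensor:
    # Accumulate in int32 with zero points removed.
    acc = (src.int() - in_zero_point) @ (weight.int() - weight_zero_point).T
    acc = acc + bias.int()
    # Assumed requantization convention (Q31 multiplier plus power-of-two
    # shift); the real kernel may round and shift differently.
    scaled = acc.double() * out_multiplier * (2.0 ** out_shift) / (1 << 31)
    out = torch.round(scaled) + out_zero_point
    return torch.clamp(out, -128, 127).to(torch.int8)
```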

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 31 additions & 22 deletions
```diff
@@ -307,36 +307,45 @@ def test_quantized_linear(
         if per_tensor:
             match expected_output.dtype:
                 case torch.int8:
-                    linear_op = (
-                        torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor
+                    linear_ops = (
+                        torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
+                        torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
                     )
                 case torch.uint8:
-                    linear_op = (
-                        torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor
+                    linear_ops = (
+                        torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
+                        torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
                     )
                 case _:
-                    linear_op = torch.ops.cadence.quantized_linear.per_tensor
+                    linear_ops = (
+                        torch.ops.cadence.quantized_linear.per_tensor,
+                        torch.ops.cadence.quantized_fully_connected.per_tensor,
+                    )
         else:
-            linear_op = torch.ops.cadence.quantized_linear
+            linear_ops = (
+                torch.ops.cadence.quantized_linear,
+                torch.ops.cadence.quantized_fully_connected,
+            )
 
-        output = linear_op(
-            src,
-            weight,
-            bias,
-            in_zero_point,
-            weight_zero_point,
-            out_multiplier,
-            out_shift,
-            out_zero_point,
-            typing.cast(torch.Tensor, None),
-        )
+        for linear_op in linear_ops:
+            output = linear_op(
+                src,
+                weight,
+                bias,
+                in_zero_point,
+                weight_zero_point,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+                typing.cast(torch.Tensor, None),
+            )
 
-        self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")
+            self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")
 
-        self.assertTrue(
-            torch.equal(output, expected_output),
-            f"Values don't match: got {output}, expected {expected_output}",
-        )
+            self.assertTrue(
+                torch.equal(output, expected_output),
+                f"Values don't match: got {output}, expected {expected_output}",
+            )
 
     @expand(
         [
```
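
The updated test exercises every alias with a single body by looping over `linear_ops`. The same pattern in a minimal, runnable form, with stand-in ops rather than the `torch.ops.cadence` namespace:

```python
import unittest

import torch


def _linear(src: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return src @ weight.T


# Stand-in ops; the real test calls torch.ops.cadence.* overloads.
quantized_linear = _linear
quantized_fully_connected = _linear


class AliasTest(unittest.TestCase):
    def test_aliases_agree(self) -> None:
        src = torch.randint(-128, 127, (1, 8), dtype=torch.int32)
        weight = torch.randint(-128, 127, (4, 8), dtype=torch.int32)
        expected = _linear(src, weight)
        for linear_op in (quantized_linear, quantized_fully_connected):
            output = linear_op(src, weight)
            self.assertTrue(output.dtype == expected.dtype, "Dtype mismatch")
            self.assertTrue(torch.equal(output, expected))


if __name__ == "__main__":
    unittest.main()
```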
