
Commit b0eb57f

mcremon-meta authored and facebook-github-bot committed
Fix quantized linear -> quantized fully connected replacement pass + add quantized fully connected per_tensor
Summary: As titled. This allows removing the outer loop unrolling in cases where the input to linear is a vector. Shaves ~10k cycles from the WW stage 1 model.

Differential Revision: D66208417
1 parent f40daea · commit b0eb57f
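The replacement pass itself is not part of this diff, but the idea behind the fix can be sketched: a quantized_linear whose activation is a single vector (all leading dimensions multiply to 1) can be lowered to quantized_fully_connected, which processes one [1, in_dim] row and therefore needs no unrolled outer loop over leading dimensions. A minimal illustrative check follows; the helper name `can_lower_to_fully_connected` is hypothetical and not taken from the commit.

import math

import torch


def can_lower_to_fully_connected(src: torch.Tensor) -> bool:
    # src comes in shape [leading_dims, in_dim]; the fully connected kernel
    # handles a single row, so the lowering is only valid when the flattened
    # leading dims collapse to one vector.
    return math.prod(src.shape[:-1]) == 1


assert can_lower_to_fully_connected(torch.zeros(1, 64, dtype=torch.int8))
assert not can_lower_to_fully_connected(torch.zeros(4, 64, dtype=torch.int8))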

1 file changed: +31 −2 lines


backends/cadence/aot/ops_registrations.py

Lines changed: 31 additions & 2 deletions
@@ -146,7 +146,10 @@
     "quantized_fully_connected(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
 )
-
+lib.define(
+    "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
+)
 
 # ------------------------------------ #
 # Migrated from custom_ops.ymal        #
@@ -192,6 +195,10 @@
     "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
     "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
@@ -595,6 +602,28 @@ def quantized_fully_connected_meta(
     bias: torch.Tensor,
     in_zero_point: int,
     weight_zero_point: torch.Tensor,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+    out_zero_point: int,
+    offset: Optional[torch.Tensor],
+) -> torch.Tensor:
+    # src comes in shape [leading_dims, in_dim]
+    # weight comes in shape [out_dim, in_dim]
+    # output comes in empty with shape [leading_dims, out_dim]
+    out_size = list(src.size())
+    weight_size = list(weight.size())
+    assert len(weight_size) == 2
+    out_size[-1] = weight_size[0]
+    return src.new_empty(out_size, dtype=src.dtype)
+
+
+@register_fake("cadence::quantized_fully_connected.per_tensor")
+def quantized_fully_connected_per_tensor_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    in_zero_point: int,
+    weight_zero_point: int,
     out_multiplier: int,
     out_shift: int,
     out_zero_point: int,
@@ -607,7 +636,7 @@ def quantized_fully_connected_meta(
     weight_size = list(weight.size())
     assert len(weight_size) == 2
     out_size[-1] = weight_size[0]
-    return src.new_empty(out_size, dtype=torch.uint8)
+    return src.new_empty(out_size, dtype=src.dtype)
 
 
 @register_fake("cadence::convolution")
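As a quick sanity check of the rule the meta kernels above implement (output shape [leading_dims, out_dim], with the dtype now following src instead of the previously hard-coded torch.uint8), here is a small standalone sketch with illustrative sizes:

import torch

src = torch.zeros(1, 64, dtype=torch.int8)      # [leading_dims, in_dim]
weight = torch.zeros(32, 64, dtype=torch.int8)  # [out_dim, in_dim]

out_size = list(src.size())
out_size[-1] = weight.size(0)                   # -> [1, 32]
out = src.new_empty(out_size, dtype=src.dtype)

assert out.shape == (1, 32)
assert out.dtype == torch.int8  # was torch.uint8 before this fix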
