Skip to content

Commit 650efcd

Browse files
zonglinpeng authored and
facebook-github-bot committed
register quantized_linear.per_tensor in lib (#6563)
Summary: Continuation of the previous diff. Reviewed By: hsharma35. Differential Revision: D65104400
1 parent 735e019 commit 650efcd

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@
5050
"quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
5151
)
5252
lib.define(
53-
"cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
53+
"quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
54+
)
55+
lib.define(
56+
"quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, "
57+
"SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor"
5458
)
5559

5660
lib.define(
@@ -123,6 +127,28 @@ def quantized_linear_meta(
123127
return src.new_empty(out_size, dtype=src.dtype)
124128

125129

130+
@register_fake("cadence::quantized_linear.per_tensor")
def quantized_linear_per_tensor_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    in_zero_point: torch.SymInt,
    weight_zero_point: torch.SymInt,
    out_multiplier: torch.SymInt,
    out_shift: torch.SymInt,
    out_zero_point: torch.SymInt,
    offset: Optional[torch.Tensor],
) -> torch.Tensor:
    """Fake kernel for ``cadence::quantized_linear.per_tensor``.

    Only the result's shape and dtype are derived; no arithmetic is done.
    The result has the shape of ``src`` with its last dimension replaced by
    the first dimension of ``weight``, and the dtype of ``src``.

    Args:
        src: Input of shape ``[leading_dims, in_dim]``.
        weight: Weight of shape ``[out_dim, in_dim]``; must be 2-D.
        bias: Not used for shape inference.
        in_zero_point: Not used for shape inference.
        weight_zero_point: Not used for shape inference.
        out_multiplier: Not used for shape inference.
        out_shift: Not used for shape inference.
        out_zero_point: Not used for shape inference.
        offset: Not used for shape inference.

    Returns:
        An uninitialized tensor of shape ``[leading_dims, out_dim]`` created
        with ``src.new_empty``.

    Raises:
        AssertionError: If ``weight`` is not 2-dimensional.
    """
    result_shape = list(src.shape)
    weight_shape = list(weight.shape)
    # The weight has to be a [out_dim, in_dim] matrix.
    assert len(weight_shape) == 2
    # Swap the trailing in_dim of src for the weight's out_dim.
    result_shape[-1] = weight_shape[0]
    return src.new_empty(result_shape, dtype=src.dtype)
150+
151+
126152
@register_fake("cadence::quantized_conv")
127153
def quantized_conv_meta(
128154
input: torch.Tensor,

0 commit comments

Comments
 (0)