diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index fce6ce5736b..5e852b369d1 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -50,7 +50,11 @@ "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, " + "SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor" ) lib.define( @@ -129,6 +133,28 @@ def quantized_linear_meta( return src.new_empty(out_size, dtype=src.dtype) +@register_fake("cadence::quantized_linear.per_tensor") +def quantized_linear_per_tensor_meta( + src: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + in_zero_point: torch.SymInt, + weight_zero_point: torch.SymInt, + out_multiplier: torch.SymInt, + out_shift: torch.SymInt, + out_zero_point: torch.SymInt, + offset: Optional[torch.Tensor], +) -> torch.Tensor: + # src comes in shape [leading_dims, in_dim] + # weight comes in shape [out_dim, in_dim] + # output comes in empty with shape [leading_dims, out_dim] + out_size = list(src.size()) + weight_size = list(weight.size()) + assert len(weight_size) == 2 + out_size[-1] = weight_size[0] + return src.new_empty(out_size, dtype=src.dtype) + + @register_fake("cadence::quantized_conv") def quantized_conv_meta( input: torch.Tensor,