diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index d47ea3f21af..fce6ce5736b 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -66,6 +66,12 @@ lib.define( "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "quantized_conv.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False) -> (Tensor Z)" +) +lib.define( + "quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define( "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)" @@ -171,6 +177,54 @@ def quantized_conv_meta( return input.new_empty(output_size, dtype=input.dtype) +@register_fake("cadence::quantized_conv.per_tensor") +def quantized_conv_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, + channel_last: bool = False, +) -> torch.Tensor: + if channel_last: + out_channels, *kernel_size, _ = weight.shape + else: + out_channels, _, *kernel_size = weight.shape + + in_size = input.shape + # Assert that the input tensor has at least 3 dimensions, and at most 6 + assert len(in_size) > 2 + assert len(in_size) < 6 + + # Compute the output tensor size + output_size = ( + get_conv1d_output_size( + in_size, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size[0], + channel_last, + ) + if len(in_size) == 3 + else get_conv2d_output_size( + in_size, out_channels, stride, padding, dilation, kernel_size, channel_last + ) + ) + + return input.new_empty(output_size, dtype=input.dtype) + + @register_fake("cadence::quantized_layer_norm") def quantized_layer_norm_meta( input: torch.Tensor,