diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml
index 2e9e187168f..d8024c0245a 100644
--- a/backends/cadence/aot/functions.yaml
+++ b/backends/cadence/aot/functions.yaml
@@ -208,6 +208,12 @@
   - arg_meta: null
     kernel_name: impl::generic::quantize_per_tensor_asym16u_out
 
+- func: cadence::quantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+  - arg_meta: null
+    kernel_name: impl::generic::quantize_per_tensor_asym32s_out
+
 - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
@@ -238,6 +244,12 @@
   - arg_meta: null
     kernel_name: impl::generic::dequantize_per_tensor_asym16u_out
 
+- func: cadence::dequantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+  - arg_meta: null
+    kernel_name: impl::generic::dequantize_per_tensor_asym32s_out
+
 - func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
   - arg_meta: null
diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml
index c48aac8686a..bcab980abd6 100644
--- a/backends/cadence/aot/functions_hifi.yaml
+++ b/backends/cadence/aot/functions_hifi.yaml
@@ -308,6 +308,11 @@
   - arg_meta: null
     kernel_name: impl::HiFi::quantize_per_tensor_asym16s_out
 
+- func: cadence::quantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+  - arg_meta: null
+    kernel_name: impl::HiFi::quantize_per_tensor_asym32s_out
 - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
@@ -339,6 +344,12 @@
   - arg_meta: null
     kernel_name: impl::HiFi::dequantize_per_tensor_asym16u_out
 
+- func: cadence::dequantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+  - arg_meta: null
+    kernel_name: impl::HiFi::dequantize_per_tensor_asym32s_out
+
 - func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
   - arg_meta: null
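All four bindings above implement the same asymmetric affine scheme; the asym32s variants simply widen the quantized type to int32. As a rough illustration of the per-element math these kernels implement (a Python sketch with hypothetical names, not the kernel code itself):

    def quantize_per_tensor_asym32s_ref(x, scale, zero_point, quant_min, quant_max):
        # q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
        q = round(x / scale) + zero_point
        return max(quant_min, min(quant_max, q))

    # Example: scale=0.5, zero_point=0, full int32 range.
    assert quantize_per_tensor_asym32s_ref(3.2, 0.5, 0, -(2**31), 2**31 - 1) == 6
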
diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 567d86af457..bd208d04739 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -56,6 +56,13 @@
     "quantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantize_per_tensor_asym32s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "quantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 lib.define(
     "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
 )
@@ -87,6 +94,13 @@
     "dequantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "dequantize_per_tensor_asym32s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "dequantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 lib.define(
     "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
 )
@@ -641,6 +655,18 @@ def quantize_per_tensor_asym16u_meta(
     return input.new_empty(input.size(), dtype=dtype)
 
 
+@register_fake("cadence::quantize_per_tensor_asym32s")
+def quantize_per_tensor_asym32s_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=dtype)
+
+
 @register_fake("cadence::dequantize_per_tensor")
 def dequantize_per_tensor_meta(
     input: torch.Tensor,
@@ -701,6 +727,18 @@ def dequantize_per_tensor_asym16u_meta(
     return input.new_empty(input.size(), dtype=torch.float)
 
 
+@register_fake("cadence::dequantize_per_tensor_asym32s")
+def dequantize_per_tensor_asym32s_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=torch.float)
+
+
 @register_fake("cadence::quantized_add")
 def quantized_add_meta(
     X: torch.Tensor,
diff --git a/backends/cadence/aot/type_dispatch.py b/backends/cadence/aot/type_dispatch.py
index 97a25938e8d..37f753767e9 100644
--- a/backends/cadence/aot/type_dispatch.py
+++ b/backends/cadence/aot/type_dispatch.py
@@ -108,6 +108,7 @@ class CompileTimeTypeDispatchPass(ExportPass):
                 (torch.uint8,): "asym8u",
                 (torch.int16,): "asym16s",
                 (torch.uint16,): "asym16s",
+                (torch.int32,): "asym32s",
             },
             variant="default",
             is_quant_op=True,
@@ -119,6 +120,7 @@ class CompileTimeTypeDispatchPass(ExportPass):
                 (torch.uint8,): "asym8u",
                 (torch.int16,): "asym16s",
                 (torch.uint16,): "asym16s",
+                (torch.int32,): "asym32s",
             },
             variant="default",
         ),
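The schema definitions and @register_fake meta kernels above make the new ops traceable, and the CompileTimeTypeDispatchPass entries route the generic per-tensor ops to the typed variants during compilation. Roughly, the table maps a dtype to a kernel-name suffix; the helper below is a hypothetical simplification (the real pass rewrites graph nodes, not strings):

    import torch

    # Mirror of the dispatch table above (only the entries visible in this diff).
    SUFFIXES = {
        torch.uint8: "asym8u",
        torch.int16: "asym16s",
        torch.uint16: "asym16s",
        torch.int32: "asym32s",  # new in this change
    }

    def typed_variant(base: str, dtype: torch.dtype) -> str:
        # typed_variant("quantize_per_tensor", torch.int32)
        #   -> "quantize_per_tensor_asym32s"
        return f"{base}_{SUFFIXES[dtype]}"
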
diff --git a/backends/cadence/generic/kernels/kernels.cpp b/backends/cadence/generic/kernels/kernels.cpp
index 568d8468af9..25e25cfa60a 100644
--- a/backends/cadence/generic/kernels/kernels.cpp
+++ b/backends/cadence/generic/kernels/kernels.cpp
@@ -73,6 +73,7 @@ typed_quantize_val(int8_t);
 typed_quantize_val(uint8_t);
 typed_quantize_val(int16_t);
 typed_quantize_val(uint16_t);
+typed_quantize_val(int32_t);
 #undef typed_quantize_val
 
 #define typed_quantize_vec(dtype) \
@@ -86,6 +87,7 @@ typed_quantize_vec(int8_t);
 typed_quantize_vec(uint8_t);
 typed_quantize_vec(int16_t);
 typed_quantize_vec(uint16_t);
+typed_quantize_vec(int32_t);
 #undef typed_quantize_vec
 
 #define typed_dequantize_val(dtype) \
@@ -94,6 +96,7 @@ typed_dequantize_val(int8_t);
 typed_dequantize_val(uint8_t);
 typed_dequantize_val(int16_t);
 typed_dequantize_val(uint16_t);
+typed_dequantize_val(int32_t);
 #undef typed_dequantize_val
 
 #define typed_dequantize_vec(dtype) \
@@ -107,6 +110,7 @@ typed_dequantize_vec(int8_t);
 typed_dequantize_vec(uint8_t);
 typed_dequantize_vec(int16_t);
 typed_dequantize_vec(uint16_t);
+typed_dequantize_vec(int32_t);
 #undef typed_dequantize_vec
 
 } // namespace kernels
diff --git a/backends/cadence/generic/operators/dequantize_per_tensor.cpp b/backends/cadence/generic/operators/dequantize_per_tensor.cpp
index aedc6e10309..ec05272da1b 100644
--- a/backends/cadence/generic/operators/dequantize_per_tensor.cpp
+++ b/backends/cadence/generic/operators/dequantize_per_tensor.cpp
@@ -44,6 +44,9 @@ Tensor& dequantize_per_tensor_out(
   } else if (input.scalar_type() == ScalarType::Short) {
     const int16_t* input_data = input.const_data_ptr<int16_t>();
     dequantize<int16_t>(out_data, input_data, scale, zero_point, numel);
+  } else if (input.scalar_type() == ScalarType::Int) {
+    const int32_t* input_data = input.const_data_ptr<int32_t>();
+    dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
   } else {
     ET_CHECK_MSG(
         false,
@@ -117,6 +120,22 @@ Tensor& dequantize_per_tensor_asym16u_out(
   return out;
 }
 
+Tensor& dequantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  float* out_data = out.mutable_data_ptr<float>();
+  size_t numel = out.numel();
+  const int32_t* input_data = input.const_data_ptr<int32_t>();
+  dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
+  return out;
+}
+
 } // namespace native
 } // namespace generic
 } // namespace impl
diff --git a/backends/cadence/generic/operators/quantize_per_tensor.cpp b/backends/cadence/generic/operators/quantize_per_tensor.cpp
index f2a413be35d..8ce70d2b51d 100644
--- a/backends/cadence/generic/operators/quantize_per_tensor.cpp
+++ b/backends/cadence/generic/operators/quantize_per_tensor.cpp
@@ -46,6 +46,9 @@ Tensor& quantize_per_tensor_out(
   } else if (out.scalar_type() == ScalarType::Short) {
     int16_t* out_data = out.mutable_data_ptr<int16_t>();
     quantize<int16_t>(out_data, input_data, 1. / scale, zero_point, numel);
+  } else if (out.scalar_type() == ScalarType::Int) {
+    int32_t* out_data = out.mutable_data_ptr<int32_t>();
+    quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else {
     ET_CHECK_MSG(
         false,
@@ -119,6 +122,22 @@ Tensor& quantize_per_tensor_asym16u_out(
   return out;
 }
 
+Tensor& quantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  const float* input_data = input.const_data_ptr<float>();
+  size_t numel = out.numel();
+  int32_t* out_data = out.mutable_data_ptr<int32_t>();
+  quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
+  return out;
+}
+
 }; // namespace native
 }; // namespace generic
 }; // namespace impl
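The generic quantize and dequantize paths above are exact inverses on values that land on the int32 grid. Continuing the earlier Python sketch (quantize_per_tensor_asym32s_ref is the hypothetical reference from above):

    def dequantize_per_tensor_asym32s_ref(q, scale, zero_point):
        # Inverse affine transform: x = (q - zero_point) * scale.
        return (q - zero_point) * scale

    # Grid-aligned values survive the round trip exactly.
    scale, zero_point = 0.25, 8
    q = quantize_per_tensor_asym32s_ref(1.75, scale, zero_point, -(2**31), 2**31 - 1)
    assert q == 15
    assert dequantize_per_tensor_asym32s_ref(q, scale, zero_point) == 1.75
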
diff --git a/backends/cadence/hifi/kernels/kernels.cpp b/backends/cadence/hifi/kernels/kernels.cpp
index d9223d7bd18..237c605443f 100644
--- a/backends/cadence/hifi/kernels/kernels.cpp
+++ b/backends/cadence/hifi/kernels/kernels.cpp
@@ -127,6 +127,7 @@ typed_quantize_val(int8_t);
 typed_quantize_val(uint8_t);
 typed_quantize_val(int16_t);
 typed_quantize_val(uint16_t);
+typed_quantize_val(int32_t);
 #undef typed_quantize_val
 
 #define typed_quantize_vec(dtype) \
@@ -150,6 +151,7 @@ typed_dequantize_val(int8_t);
 typed_dequantize_val(uint8_t);
 typed_dequantize_val(int16_t);
 typed_dequantize_val(uint16_t);
+typed_dequantize_val(int32_t);
 #undef typed_dequantize_val
 
 #define typed_dequantize_vec(dtype) \
diff --git a/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp
index 317e7ed8ef9..30ce938e24d 100644
--- a/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp
+++ b/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp
@@ -45,6 +45,9 @@ void dequantize_per_tensor_out(
       input.scalar_type() == ScalarType::UInt16) {
     const uint16_t* input_data = input.const_data_ptr<uint16_t>();
     dequantize<uint16_t>(out_data, input_data, scale, zero_point, numel);
+  } else if (input.scalar_type() == ScalarType::Int) {
+    const int32_t* input_data = input.const_data_ptr<int32_t>();
+    dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
   } else {
     ET_CHECK_MSG(
         false,
@@ -98,6 +101,21 @@ void dequantize_per_tensor_asym16u_out(
   dequantize<uint16_t>(out_data, input_data, scale, zero_point, numel);
 }
 
+void dequantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  float* out_data = out.mutable_data_ptr<float>();
+  size_t numel = out.numel();
+  const int32_t* input_data = input.const_data_ptr<int32_t>();
+  dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
+}
+
 } // namespace native
 } // namespace HiFi
 } // namespace impl
diff --git a/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp b/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp
index 9bc3d48699e..579a4533057 100644
--- a/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp
+++ b/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp
@@ -108,6 +108,9 @@ void quantize_per_tensor_out(
       out.scalar_type() == ScalarType::UInt16) {
     uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
     quantize<uint16_t>(out_data, input_data, 1. / scale, zero_point, numel);
+  } else if (out.scalar_type() == ScalarType::Int) {
+    int32_t* out_data = out.mutable_data_ptr<int32_t>();
+    quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
   } else {
     ET_KERNEL_CHECK_MSG(
         ctx,
@@ -164,6 +167,21 @@ void quantize_per_tensor_asym16u_out(
   quantize<uint16_t>(out_data, input_data, 1. / scale, zero_point, numel);
 }
 
+void quantize_per_tensor_asym32s_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  const float* input_data = input.const_data_ptr<float>();
+  size_t numel = out.numel();
+  int32_t* out_data = out.mutable_data_ptr<int32_t>();
+  quantize<int32_t>(out_data, input_data, 1. / scale, zero_point, numel);
+}
+
 }; // namespace native
 }; // namespace HiFi
 }; // namespace impl
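With the schemas and fake kernels registered, dtype and shape propagation for the new ops can be smoke-tested from Python without Cadence hardware. A sketch, assuming the registrations are importable as executorch.backends.cadence.aot.ops_registrations (module path inferred from this diff's file layout):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    from executorch.backends.cadence.aot import ops_registrations  # noqa: F401

    # Under FakeTensorMode only the @register_fake meta kernels run.
    with FakeTensorMode():
        x = torch.randn(4, 8)
        q = torch.ops.cadence.quantize_per_tensor_asym32s(
            x, 0.02, 0, -(2**31), 2**31 - 1, torch.int32
        )
        y = torch.ops.cadence.dequantize_per_tensor_asym32s(
            q, 0.02, 0, -(2**31), 2**31 - 1, torch.int32
        )
        assert q.dtype == torch.int32 and y.dtype == torch.float32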