diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index ac4417c79ae..e860a2bfcc6 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -188,7 +188,7 @@ def quantized_relu_meta(
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
 ) -> torch.Tensor:
-    return X.new_empty(X.size(), dtype=torch.uint8)
+    return X.new_empty(X.size(), dtype=X.dtype)
 
 
 @register_fake("cadence::quantized_matmul")
diff --git a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp
index f025a1cc6f2..18381a26e0a 100644
--- a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp
+++ b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp
@@ -45,7 +45,10 @@ void dequantize_per_tensor_out(
     const int32_t* input_data = input.const_data_ptr<int32_t>();
     dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
   } else {
-    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type());
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(input.scalar_type()));
   }
 }
 
diff --git a/backends/cadence/hifi/operators/quantize_per_tensor.cpp b/backends/cadence/hifi/operators/quantize_per_tensor.cpp
index 9cc84fffa38..c65d62968f5 100644
--- a/backends/cadence/hifi/operators/quantize_per_tensor.cpp
+++ b/backends/cadence/hifi/operators/quantize_per_tensor.cpp
@@ -49,7 +49,10 @@ void quantize_per_tensor_out(
     cadence::impl::HiFi::kernels::quantize<int32_t>(
         out_data, input_data, 1. / scale, zero_point, numel);
   } else {
-    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", out.scalar_type());
+    ET_CHECK_MSG(
+        false,
+        "Unhandled output dtype %hhd",
+        static_cast<int8_t>(out.scalar_type()));
   }
 }
 
diff --git a/backends/cadence/reference/operators/quantized_conv_out.cpp b/backends/cadence/reference/operators/quantized_conv_out.cpp
index b37c5884c11..de19f3ef43a 100644
--- a/backends/cadence/reference/operators/quantized_conv_out.cpp
+++ b/backends/cadence/reference/operators/quantized_conv_out.cpp
@@ -248,6 +248,11 @@ void quantized_conv_out(
         output_scale,
         (int8_t)output_zero_point,
         per_tensor_quantized);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(input.scalar_type()));
   }
 }
 
diff --git a/backends/cadence/reference/operators/quantized_linear_out.cpp b/backends/cadence/reference/operators/quantized_linear_out.cpp
index a02794c179c..7bb1bf6fb47 100644
--- a/backends/cadence/reference/operators/quantized_linear_out.cpp
+++ b/backends/cadence/reference/operators/quantized_linear_out.cpp
@@ -17,8 +17,8 @@ using executorch::aten::Tensor;
 using executorch::runtime::getLeadingDims;
 using executorch::runtime::KernelRuntimeContext;
 
-void quantized_linear_out(
-    KernelRuntimeContext& ctx,
+template <typename T>
+void inline _typed_quantized_linear(
     const Tensor& src,
     const Tensor& weight,
     const Tensor& bias,
@@ -27,14 +27,11 @@ void quantized_linear_out(
     const Tensor& out_multiplier,
     const Tensor& out_shift,
     int64_t out_zero_point,
-    const executorch::aten::optional<Tensor>& offset,
     Tensor& out) {
-  // Assuming uint8_t for now, but needs to be updated for other quantization
-  // types
-  const uint8_t* __restrict__ src_data = src.const_data_ptr<uint8_t>();
-  const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
+  const T* __restrict__ src_data = src.const_data_ptr<T>();
+  const T* __restrict__ weight_data = weight.const_data_ptr<T>();
   const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
-  uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
+  T* __restrict__ out_data = out.mutable_data_ptr<T>();
 
   int32_t weight_zero_point = weight_zero_point_t.const_data_ptr<int32_t>()[0];
 
@@ -71,11 +68,53 @@ void quantized_linear_out(
             (weight_data[j * N + k] - weight_zero_point);
       }
       out_data[i * M + j] =
-          kernels::quantize<uint8_t>(sum, out_scale, out_zero_point);
+          kernels::quantize<T>(sum, out_scale, out_zero_point);
     }
   }
 }
 
+void quantized_linear_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& src,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t src_zero_point,
+    const Tensor& weight_zero_point_t,
+    const Tensor& out_multiplier,
+    const Tensor& out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const executorch::aten::optional<Tensor>& offset,
+    Tensor& out) {
+  if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
+    _typed_quantized_linear<uint8_t>(
+        src,
+        weight,
+        bias,
+        src_zero_point,
+        weight_zero_point_t,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        out);
+  } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
+    _typed_quantized_linear<int8_t>(
+        src,
+        weight,
+        bias,
+        src_zero_point,
+        weight_zero_point_t,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        out);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(src.scalar_type()));
+  }
+}
+
 }; // namespace native
 }; // namespace reference
 }; // namespace impl
diff --git a/backends/cadence/reference/operators/quantized_matmul_out.cpp b/backends/cadence/reference/operators/quantized_matmul_out.cpp
index bf901105ea4..d12fc533e73 100644
--- a/backends/cadence/reference/operators/quantized_matmul_out.cpp
+++ b/backends/cadence/reference/operators/quantized_matmul_out.cpp
@@ -144,6 +144,11 @@ void quantized_matmul_out(
         out_zero_point,
         transposed,
         out);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(X.scalar_type()));
   }
 }
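
For context on the `_typed_quantized_linear` refactor above: the kernel body is templated on the element type, and a thin dispatcher picks the instantiation from the output tensor's runtime dtype tag. A minimal, self-contained sketch of that shape (plain C++ with a stand-in `ScalarType` enum; `typed_negate` and `negate_out` are illustrative names, not ExecuTorch API):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Stand-in for executorch::aten::ScalarType (illustrative values only).
enum class ScalarType : int8_t { Byte = 0, Char = 1 };

// Templated worker, analogous to _typed_quantized_linear<T>: the same
// arithmetic runs for uint8_t and int8_t once T is fixed at compile time.
template <typename T>
void typed_negate(T* out, const T* in, size_t n, int32_t zero_point) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = static_cast<T>(2 * zero_point - in[i]); // reflect around zp
  }
}

// Runtime dispatcher, analogous to the new quantized_linear_out: pick the
// instantiation from the dtype tag, and hard-fail on anything unhandled.
void negate_out(
    ScalarType dtype, void* out, const void* in, size_t n, int32_t zp) {
  if (dtype == ScalarType::Byte) {
    typed_negate<uint8_t>(
        static_cast<uint8_t*>(out), static_cast<const uint8_t*>(in), n, zp);
  } else if (dtype == ScalarType::Char) {
    typed_negate<int8_t>(
        static_cast<int8_t*>(out), static_cast<const int8_t*>(in), n, zp);
  } else {
    // Same failure idiom as the diff's ET_CHECK_MSG branches.
    std::fprintf(stderr, "Unhandled dtype %hhd\n", static_cast<int8_t>(dtype));
    std::abort();
  }
}

int main() {
  std::vector<int8_t> in{-3, 0, 5}, out(3);
  negate_out(ScalarType::Char, out.data(), in.data(), in.size(), /*zp=*/0);
  std::printf("%d %d %d\n", out[0], out[1], out[2]); // prints: 3 0 -5
}
```

This keeps the inner loop free of per-element branching: the dtype check happens once per call, and each `T` gets its own fully typed instantiation.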
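The `static_cast<int8_t>(...scalar_type())` added to every `ET_CHECK_MSG` call is not cosmetic: `ScalarType` is a scoped enum (declared with an `int8_t` underlying type), and scoped enums do not implicitly convert to integers, so handing one straight to a printf-style `%hhd` specifier trips `-Wformat` and relies on unspecified varargs behavior. A compilable illustration, using `std::printf` as a stand-in for `ET_CHECK_MSG` and a stand-in enum:

```cpp
#include <cstdint>
#include <cstdio>

// Stand-in for executorch::aten::ScalarType.
enum class ScalarType : int8_t { Byte = 0, Char = 1 };

int main() {
  ScalarType st = ScalarType::Char;
  // std::printf("Unhandled dtype %hhd\n", st);  // -Wformat: scoped enum
  //                                             // does not convert to int8_t
  std::printf(
      "Unhandled dtype %hhd\n", static_cast<int8_t>(st)); // OK: prints 1
  return 0;
}
```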