|
5 | 5 | #include <executorch/backends/cadence/reference/kernels/kernels.h> |
6 | 6 | #include <executorch/backends/cadence/reference/operators/operators.h> |
7 | 7 |
|
// Expand into a templated kernel applying an elementwise arithmetic operator
// to two quantized tensors: dequantize both operands to float, combine with
// OP, and requantize into `out`. Supported layouts are (a) both inputs with
// identical element counts and (b) a single-element (scalar) second input;
// anything else trips the debug check below.
#define DECLARE_POINTWISE_TENSOR_QUANTIZED_BINARY_OP(BINARY_FUNC_NAME, OP)     \
  template <typename T>                                                        \
  void BINARY_FUNC_NAME(                                                       \
      const ::executorch::aten::Tensor& X,                                     \
      float X_scale,                                                           \
      int32_t X_zero_point,                                                    \
      const ::executorch::aten::Tensor& Y,                                     \
      float Y_scale,                                                           \
      int32_t Y_zero_point,                                                    \
      float out_scale,                                                         \
      int32_t out_zero_point,                                                  \
      ::executorch::aten::Tensor& out) {                                       \
    const size_t x_n = X.numel();                                              \
    const size_t y_n = Y.numel();                                              \
    /* quantize() takes the reciprocal scale, so invert once up front. */      \
    const float recip_out_scale = 1.0f / out_scale;                            \
    const T* __restrict__ x_ptr = X.const_data_ptr<T>();                       \
    const T* __restrict__ y_ptr = Y.const_data_ptr<T>();                       \
    T* __restrict__ out_ptr = out.mutable_data_ptr<T>();                       \
    if (y_n == x_n) {                                                          \
      /* Case 1: operands have the same number of elements. */                 \
      for (size_t idx = 0; idx < x_n; ++idx) {                                 \
        const float lhs =                                                      \
            kernels::dequantize<T>(x_ptr[idx], X_scale, X_zero_point);         \
        const float rhs =                                                      \
            kernels::dequantize<T>(y_ptr[idx], Y_scale, Y_zero_point);         \
        out_ptr[idx] = kernels::quantize<T>(                                   \
            lhs OP rhs, recip_out_scale, out_zero_point);                      \
      }                                                                        \
    } else if (y_n == 1) {                                                     \
      /* Case 2: Y is a scalar tensor; dequantize it once, reuse per lane. */  \
      const float rhs =                                                        \
          kernels::dequantize<T>(y_ptr[0], Y_scale, Y_zero_point);             \
      for (size_t idx = 0; idx < x_n; ++idx) {                                 \
        const float lhs =                                                      \
            kernels::dequantize<T>(x_ptr[idx], X_scale, X_zero_point);         \
        out_ptr[idx] = kernels::quantize<T>(                                   \
            lhs OP rhs, recip_out_scale, out_zero_point);                      \
      }                                                                        \
    } else {                                                                   \
      /* Any other shape pairing is unsupported broadcasting. */               \
      ET_DCHECK_MSG(false, "Unsupported broadcasting");                        \
    }                                                                          \
  }
50 | | - |
51 | 8 | template <typename T> |
52 | 9 | inline __attribute__((always_inline)) void quantized_linear_per_tensor_( |
53 | 10 | const ::executorch::aten::Tensor& src, |
|
0 commit comments