|
7 | 7 | */ |
8 | 8 |
|
9 | 9 | #include <executorch/backends/cadence/hifi/kernels/kernels.h> |
| 10 | +#include <executorch/backends/cadence/hifi/operators/operators.h> |
10 | 11 | #include <executorch/runtime/kernel/kernel_includes.h> |
11 | 12 | #include <algorithm> |
12 | 13 | #include <cmath> |
13 | 14 | #include <tuple> |
14 | 15 |
|
15 | | -using executorch::aten::Tensor; |
16 | | -using executorch::runtime::getLeadingDims; |
17 | | -using executorch::runtime::KernelRuntimeContext; |
| 16 | +using ::executorch::aten::IntArrayRef; |
| 17 | +using ::executorch::aten::ScalarType; |
| 18 | +using ::executorch::aten::Tensor; |
| 19 | +using ::executorch::runtime::getLeadingDims; |
| 20 | +using ::executorch::runtime::KernelRuntimeContext; |
18 | 21 |
|
19 | 22 | namespace cadence { |
20 | 23 | namespace impl { |
@@ -77,10 +80,10 @@ void quantized_layer_norm_( |
77 | 80 | for (size_t j = 0; j < last_dim; ++j) { |
78 | 81 | // Since X is quantized, we dequantize it, compute fp32 result, and |
79 | 82 | // quantize the result to an int8/uint8 value. |
80 | | - float val = cadence::impl::HiFi::kernels::dequantize<T>( |
| 83 | + float val = ::cadence::impl::HiFi::kernels::dequantize<T>( |
81 | 84 | x[j], input_scale, input_zero_point); |
82 | 85 | val = (val - mean) * inv_std * weight_data[j] + bias_data[j]; |
83 | | - y[j] = cadence::impl::HiFi::kernels::quantize<T>( |
| 86 | + y[j] = ::cadence::impl::HiFi::kernels::quantize<T>( |
84 | 87 | val, output_inv_scale, output_zero_point); |
85 | 88 | } |
86 | 89 | } |
@@ -121,38 +124,37 @@ void quantized_layer_norm_out( |
121 | 124 | const Tensor& input, |
122 | 125 | const Tensor& in_scale, |
123 | 126 | const Tensor& in_zero_point, |
124 | | - const executorch::aten::IntArrayRef normalized_shape, |
| 127 | + __ET_UNUSED const IntArrayRef normalized_shape, |
125 | 128 | const Tensor& weight, |
126 | 129 | const Tensor& bias, |
127 | 130 | double eps, |
128 | 131 | double output_scale, |
129 | 132 | int64_t output_zero_point, |
130 | 133 | Tensor& out) { |
131 | | - if (input.scalar_type() == executorch::aten::ScalarType::Byte) { |
132 | | - quantized_layer_norm_<uint8_t>( |
133 | | - input, |
134 | | - in_scale, |
135 | | - in_zero_point, |
136 | | - weight, |
137 | | - bias, |
138 | | - eps, |
139 | | - output_scale, |
140 | | - output_zero_point, |
141 | | - out); |
142 | | - } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { |
143 | | - quantized_layer_norm_<int8_t>( |
144 | | - input, |
145 | | - in_scale, |
146 | | - in_zero_point, |
147 | | - weight, |
148 | | - bias, |
149 | | - eps, |
150 | | - output_scale, |
151 | | - output_zero_point, |
152 | | - out); |
153 | | - } else { |
154 | | - ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type()); |
| 134 | +#define typed_quantized_layer_norm(ctype, dtype) \ |
| 135 | + case ScalarType::dtype: { \ |
| 136 | + quantized_layer_norm_<ctype>( \ |
| 137 | + input, \ |
| 138 | + in_scale, \ |
| 139 | + in_zero_point, \ |
| 140 | + weight, \ |
| 141 | + bias, \ |
| 142 | + eps, \ |
| 143 | + output_scale, \ |
| 144 | + output_zero_point, \ |
| 145 | + out); \ |
| 146 | + break; \ |
155 | 147 | } |
| 148 | + |
| 149 | + ScalarType dtype = input.scalar_type(); |
| 150 | + switch (dtype) { |
| 151 | + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_layer_norm) |
| 152 | + default: |
| 153 | + ET_DCHECK_MSG( |
| 154 | + false, "Unhandled dtype %s", torch::executor::toString(dtype)); |
| 155 | + } |
| 156 | + |
| 157 | +#undef typed_quantized_layer_norm |
156 | 158 | } |
157 | 159 |
|
158 | 160 | }; // namespace native |
|
0 commit comments