|
2 | 2 | // Auto Generated Code for FastGeluOpPackage |
3 | 3 | //============================================================================== |
4 | 4 |
|
| 5 | +#include <algorithm> |
5 | 6 | #include <cmath> |
6 | 7 | #include "HTP/core/constraints.h" |
7 | 8 | #include "HTP/core/op_package_feature_support.h" |
@@ -80,27 +81,179 @@ DEF_PACKAGE_OP((fastgeluImpl<Tensor>), "FastGelu") |
80 | 81 |
|
81 | 82 | /* execute functions for ops */ |
82 | 83 |
|
| 84 | +// template <typename TensorType> |
| 85 | +// GraphStatus fastgeluImpl(TensorType& y, const TensorType& x) { |
| 86 | +// const uint32_t numElements = x.total_storage_elements(); |
| 87 | + |
| 88 | +// if (y.total_storage_elements() != numElements) { |
| 89 | +// return GraphStatus::ErrorFatal; |
| 90 | +// } |
| 91 | + |
| 92 | +// const float kAlpha = 0.7978845608f; // sqrt(2/pi) |
| 93 | +// const float kCoeff = 0.044715f; |
| 94 | + |
| 95 | +// float* yData = reinterpret_cast<float*>(y.raw_data()); |
| 96 | +// const float* xData = reinterpret_cast<const float*>(x.raw_data_const()); |
| 97 | + |
| 98 | +// for (uint32_t i = 0; i < numElements; ++i) { |
| 99 | +// const float v = xData[i]; |
| 100 | +// const float inner = kAlpha * (v + kCoeff * v * v * v); |
| 101 | +// yData[i] = 0.5f * v * (1.0f + std::tanh(inner)); |
| 102 | +// } |
| 103 | + |
| 104 | +// return GraphStatus::Success; |
| 105 | +// } |
| 106 | + |
83 | 107 | template <typename TensorType> |
84 | 108 | GraphStatus fastgeluImpl(TensorType& y, const TensorType& x) { |
85 | | - const uint32_t numElements = x.total_storage_elements(); |
| 109 | + const uint32_t N = x.total_storage_elements(); |
86 | 110 |
|
87 | | - if (y.total_storage_elements() != numElements) { |
| 111 | + if (y.total_storage_elements() != N) { |
88 | 112 | return GraphStatus::ErrorFatal; |
89 | 113 | } |
90 | 114 |
|
91 | | - const float kAlpha = 0.7978845608f; // sqrt(2/pi) |
92 | | - const float kCoeff = 0.044715f; |
| 115 | + const auto in_info = x.get_dtype_intfc(); |
| 116 | + const auto out_info = y.get_dtype_intfc(); |
93 | 117 |
|
94 | | - float* yData = reinterpret_cast<float*>(y.raw_data()); |
95 | | - const float* xData = reinterpret_cast<const float*>(x.raw_data_const()); |
96 | | - |
97 | | - for (uint32_t i = 0; i < numElements; ++i) { |
98 | | - const float v = xData[i]; |
99 | | - const float inner = kAlpha * (v + kCoeff * v * v * v); |
100 | | - yData[i] = 0.5f * v * (1.0f + std::tanh(inner)); |
| 118 | + if (in_info.dtype != DType::Float32 || in_info.dtype != DType::QUInt8) { |
| 119 | + return GraphStatus::ErrorPrecision; |
101 | 120 | } |
| 121 | + if (in_info.dtype == DType::Float32 && out_info.dtype == DType::Float32) { |
| 122 | + const float* xData = static_cast<const float*>(x.raw_data_const()); |
| 123 | + float* yData = static_cast<float*>(y.raw_data()); |
| 124 | + |
| 125 | + // --- Temporary FP16 buffers --- |
| 126 | + std::vector<Float16> tmp_in(N); |
| 127 | + std::vector<Float16> tmp_out(N); |
| 128 | + |
| 129 | + for (uint32_t i = 0; i < N; ++i) { |
| 130 | + tmp_in[i] = static_cast<Float16>(xData[i]); |
| 131 | + } |
| 132 | + |
| 133 | +#ifdef __hexagon__ |
| 134 | + union { |
| 135 | + Float16 f; |
| 136 | + uint16_t b; |
| 137 | + } kAlpha = {(Float16)0.7978845608f}; // sqrt(2/pi) |
| 138 | + union { |
| 139 | + Float16 f; |
| 140 | + uint16_t b; |
| 141 | + } kCoeff = {(Float16)0.044715f}; |
| 142 | + union { |
| 143 | + Float16 f; |
| 144 | + uint16_t b; |
| 145 | + } kHalf = {(Float16)0.5f}; |
| 146 | + union { |
| 147 | + Float16 f; |
| 148 | + uint16_t b; |
| 149 | + } kOne = {(Float16)1.0f}; |
| 150 | + union { |
| 151 | + Float16 f; |
| 152 | + uint16_t b; |
| 153 | + } k27 = {(Float16)27.0f}; |
| 154 | + union { |
| 155 | + Float16 f; |
| 156 | + uint16_t b; |
| 157 | + } kInv27 = {(Float16)(1.0f / 27.0f)}; |
| 158 | + union { |
| 159 | + Float16 f; |
| 160 | + uint16_t b; |
| 161 | + } kOne3 = {(Float16)(1.0f / 3.0f)}; |
| 162 | + union { |
| 163 | + Float16 f; |
| 164 | + uint16_t b; |
| 165 | + } kOne9 = {(Float16)(1.0f / 9.0f)}; |
| 166 | + |
| 167 | + HVX_Vector v_alpha = Q6_Vh_vsplat_R(kAlpha.b); |
| 168 | + HVX_Vector v_coeff = Q6_Vh_vsplat_R(kCoeff.b); |
| 169 | + HVX_Vector v_half = Q6_Vh_vsplat_R(kHalf.b); |
| 170 | + HVX_Vector v_one = Q6_Vh_vsplat_R(kOne.b); |
| 171 | + HVX_Vector v_27 = Q6_Vh_vsplat_R(k27.b); |
| 172 | + HVX_Vector v_inv27 = Q6_Vh_vsplat_R(kInv27.b); |
| 173 | + HVX_Vector v_1_3 = Q6_Vh_vsplat_R(kOne3.b); |
| 174 | + HVX_Vector v_1_9 = Q6_Vh_vsplat_R(kOne9.b); |
| 175 | + |
| 176 | + const int VBYTES = 128; |
| 177 | + const int ELEMS = VBYTES / sizeof(Float16); // 64 |
102 | 178 |
|
103 | | - return GraphStatus::Success; |
| 179 | + for (uint32_t i = 0; i < N; i += ELEMS) { |
| 180 | + HVX_Vector vx = q6op_V_vldu_A(&tmp_in[i]); // x |
| 181 | + HVX_Vector vx2 = Q6_Vhf_vmpy_VhfVhf(vx, vx); // x^2 |
| 182 | + HVX_Vector vx3 = Q6_Vhf_vmpy_VhfVhf(vx2, vx); // x^3 |
| 183 | + |
| 184 | + // z = α * (x + c*x^3) |
| 185 | + HVX_Vector vcx3 = Q6_Vhf_vmpy_VhfVhf(vx3, v_coeff); |
| 186 | + HVX_Vector vsum = Q6_Vhf_vadd_VhfVhf(vx, vcx3); |
| 187 | + HVX_Vector vz = Q6_Vhf_vmpy_VhfVhf(vsum, v_alpha); |
| 188 | + |
| 189 | + // z^2, z^4 |
| 190 | + HVX_Vector vz2 = Q6_Vhf_vmpy_VhfVhf(vz, vz); |
| 191 | + HVX_Vector vz4 = Q6_Vhf_vmpy_VhfVhf(vz2, vz2); |
| 192 | + |
| 193 | + // inv_den ≈ (1/27) * (1 - (1/3) z^2 + (1/9) z^4) |
| 194 | + HVX_Vector term1 = Q6_Vhf_vmpy_VhfVhf(vz2, v_1_3); // (1/3) z^2 |
| 195 | + HVX_Vector one_m_t = Q6_Vhf_vsub_VhfVhf(v_one, term1); // 1 - (1/3) z^2 |
| 196 | + HVX_Vector term2 = Q6_Vhf_vmpy_VhfVhf(vz4, v_1_9); // (1/9) z^4 |
| 197 | + HVX_Vector poly = |
| 198 | + Q6_Vhf_vadd_VhfVhf(one_m_t, term2); // 1 - 1/3 z^2 + 1/9 z^4 |
| 199 | + HVX_Vector inv_den = Q6_Vhf_vmpy_VhfVhf(poly, v_inv27); // * (1/27) |
| 200 | + |
| 201 | + // num = z * (27 + z^2) = 27z + z^3 |
| 202 | + HVX_Vector z3 = Q6_Vhf_vmpy_VhfVhf(vz2, vz); |
| 203 | + HVX_Vector t27z = Q6_Vhf_vmpy_VhfVhf(vz, v_27); |
| 204 | + HVX_Vector num = Q6_Vhf_vadd_VhfVhf(t27z, z3); |
| 205 | + |
| 206 | + // tanh(z) ≈ num * inv_den |
| 207 | + HVX_Vector vtanh = Q6_Vhf_vmpy_VhfVhf(num, inv_den); |
| 208 | + |
| 209 | + // y = 0.5 * x * (1 + tanh) |
| 210 | + HVX_Vector one_plus_tanh = Q6_Vhf_vadd_VhfVhf(v_one, vtanh); |
| 211 | + HVX_Vector t = Q6_Vhf_vmpy_VhfVhf(vx, one_plus_tanh); |
| 212 | + HVX_Vector vy = Q6_Vhf_vmpy_VhfVhf(t, v_half); |
| 213 | + |
| 214 | + q6op_vstu_AV(&tmp_out[i], vy); |
| 215 | + } |
| 216 | +#else |
| 217 | + // Scalar fallback |
| 218 | + for (uint32_t i = 0; i < N; ++i) { |
| 219 | + const float v = xData[i]; |
| 220 | + const float inner = 0.7978845608f * (v + 0.044715f * v * v * v); |
| 221 | + yData[i] = 0.5f * v * (1.0f + std::tanh(inner)); |
| 222 | + } |
| 223 | +#endif |
| 224 | + |
| 225 | + for (uint32_t i = 0; i < N; ++i) { |
| 226 | + yData[i] = static_cast<float>(tmp_out[i]); |
| 227 | + } |
| 228 | + return GraphStatus::Success; |
| 229 | + } else if (in_info.dtype == DType::QUInt8) { |
| 230 | + const uint8_t* xData = static_cast<const uint8_t*>(x.raw_data_const()); |
| 231 | + uint8_t* yData = static_cast<uint8_t*>(y.raw_data()); |
| 232 | + |
| 233 | + const float x_scale = in_info.scale; |
| 234 | + const float y_scale = out_info.scale; |
| 235 | + const int32_t x_zero = in_info.offset; |
| 236 | + const int32_t y_zero = out_info.offset; |
| 237 | + |
| 238 | + alignas(128) static uint8_t lut[256]; |
| 239 | + static bool lut_init = false; |
| 240 | + if (!lut_init) { |
| 241 | + for (int i = 0; i < 256; ++i) { |
| 242 | + float x_f = (i - x_zero) * x_scale; |
| 243 | + float inner = 0.7978845608f * (x_f + 0.044715f * x_f * x_f * x_f); |
| 244 | + float y_f = 0.5f * x_f * (1.0f + std::tanh(inner)); |
| 245 | + int y_q = static_cast<int>(std::round(y_f / y_scale)) + y_zero; |
| 246 | + lut[i] = static_cast<uint8_t>(std::clamp(y_q, 0, 255)); |
| 247 | + } |
| 248 | + lut_init = true; |
| 249 | + } |
| 250 | + for (uint32_t i = 0; i < N; ++i) { |
| 251 | + yData[i] = lut[xData[i]]; |
| 252 | + } |
| 253 | + return GraphStatus::Success; |
| 254 | + } else { |
| 255 | + return GraphStatus::ErrorFatal; |
| 256 | + } |
104 | 257 | } |
105 | 258 |
|
106 | 259 | __attribute__((unused)) static float fastgeluCostFunc(const Op* op) { |
|
0 commit comments