diff --git a/extension/llm/custom_ops/spinquant/fast_hadamard_transform.cpp b/extension/llm/custom_ops/spinquant/fast_hadamard_transform.cpp
index a9f6b7de753..dd34e8da852 100644
--- a/extension/llm/custom_ops/spinquant/fast_hadamard_transform.cpp
+++ b/extension/llm/custom_ops/spinquant/fast_hadamard_transform.cpp
@@ -11,6 +11,43 @@
 #include <memory>
 
 namespace executorch {
+namespace {
+// Normalization step: divide by sqrt(1 << log2_vec_size). Similar
+// to fast_sqrt above, if N is even, then the maximum-precision way
+// to do this is right-shift by log2_vec_size / 2. If N is odd, we
+// still do the right-shift, and then we have an extra division by
+// sqrt(2) that we perform by making use of a sufficiently accurate
+// rational approximation. Our initial idea was to divide by sqrt(2)
+// by adjusting the quantization scale, but that would cause this
+// function to tend to increase the magnitude of the elements of
+// vec, which would result in clipping and therefore accuracy
+// loss, especially compounded over 30+ transformer layers.
+void quantized_normalize_after_fht(
+    const int32_t* tmp,
+    int16_t* out,
+    int log2_vec_size,
+    int vec_size) {
+  const int log2_sqrt_vec_size = log2_vec_size / 2;
+  constexpr int32_t qmin = -(1 << 15) + 1;
+  constexpr int32_t qmax = -qmin;
+  if (log2_vec_size % 2 != 0) {
+    // 408 / 577 - 1.0 / sqrt(2) ~= 1.062e-06, which should be close enough.
+    static const int32_t inv_sqrt_2_numerator = 408;
+    static const int32_t inv_sqrt_2_denominator = 577;
+    for (int ii = 0; ii < vec_size; ++ii) {
+      const auto val_over_sqrt_vec_size =
+          (tmp[ii] * inv_sqrt_2_numerator / inv_sqrt_2_denominator) >>
+          log2_sqrt_vec_size;
+      out[ii] = std::clamp(val_over_sqrt_vec_size, qmin, qmax);
+    }
+  } else {
+    for (int ii = 0; ii < vec_size; ++ii) {
+      out[ii] = std::clamp(tmp[ii] >> log2_sqrt_vec_size, qmin, qmax);
+    }
+  }
+}
+} // namespace
+
 void fast_hadamard_transform_symmetric_quantized_s16(
     int16_t* vec,
     int log2_vec_size) {
@@ -27,7 +64,7 @@ void fast_hadamard_transform_symmetric_quantized_s16(
   auto tmp = std::make_unique<int32_t[]>(vec_size);
   std::copy(vec, vec + vec_size, tmp.get());
 
-  // Per the function-level comment in the header, we can ignore the
+  // Per the function-level comment above, we can ignore the
   // quantization scale, so we just delegate to the usual unnormalized
   // implementation.
   // NOTE: if we need this to be fast on CPU, we can use FFHT to
@@ -35,34 +72,30 @@ void fast_hadamard_transform_symmetric_quantized_s16(
   internal::fast_hadamard_transform_unnormalized_simple_impl(
       tmp.get(), log2_vec_size);
 
-  // Normalization step: divide by sqrt(1 << log2_vec_size). Similar
-  // to fast_sqrt, if N is even, then the maximum-precision way
-  // to do this is right-shift by log2_vec_size / 2. If N is odd, we
-  // still do the right-shift, and then we have an extra division by
-  // sqrt(2) that we perform by making use of a sufficiently accurate
-  // rational approximation. (Our initial idea was to divide by sqrt(2)
-  // by adjusting the quantization scale, but that would cause this
-  // function to tend to increase the magnitude of the elements of
-  // vec, which would resulting in clipping and therefore accuracy
-  // loss, especially compounded over 30+ transformer layers.)
-  const int log2_sqrt_vec_size = log2_vec_size / 2;
-  constexpr int32_t qmin = -(1 << 15) + 1;
-  constexpr int32_t qmax = -qmin;
-  if (log2_vec_size % 2 != 0) {
-    // 408 / 577 - 1.0 / sqrt(2) ~= 1.062e-0.6, which should be close enough.
-    static const int32_t inv_sqrt_2_numerator = 408;
-    static const int32_t inv_sqrt_2_denominator = 577;
-    for (int ii = 0; ii < vec_size; ++ii) {
-      const auto val_over_sqrt_vec_size =
-          (tmp[ii] * inv_sqrt_2_numerator / inv_sqrt_2_denominator) >>
-          log2_sqrt_vec_size;
-      vec[ii] = std::clamp(val_over_sqrt_vec_size, qmin, qmax);
-    }
-  } else {
-    for (int ii = 0; ii < vec_size; ++ii) {
-      vec[ii] = std::clamp(tmp[ii] >> log2_sqrt_vec_size, qmin, qmax);
-    }
+  quantized_normalize_after_fht(tmp.get(), vec, log2_vec_size, vec_size);
+}
+
+void fast_hadamard_transform_symmetric_quantized_s16_28N(
+    int16_t* vec,
+    int log2_vec_size) {
+  if (log2_vec_size == 0) {
+    return;
   }
-  return;
+  const int vec_size = (1 << log2_vec_size);
+
+  auto tmp = std::make_unique<int32_t[]>(vec_size * 28);
+  std::copy(vec, vec + vec_size * 28, tmp.get());
+
+  for (int ii = 0; ii < 28; ++ii) {
+    internal::fast_hadamard_transform_unnormalized_simple_impl(
+        &tmp[ii * vec_size], log2_vec_size);
+  }
+
+  for (int ii = 0; ii < vec_size; ++ii) {
+    hadamard_mult_28_strided(&tmp[ii], vec_size);
+  }
+
+  quantized_normalize_after_fht(tmp.get(), vec, log2_vec_size, vec_size * 28);
 }
+
 } // namespace executorch
diff --git a/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h b/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h
index 4cddf7d4807..4f8d205dbd4 100644
--- a/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h
+++ b/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h
@@ -112,4 +112,11 @@ void fast_hadamard_transform_28N(T* vec, int log2_vec_size) {
   }
 }
 
+// We don't need the quantization scale; see the function-level
+// comment on fast_hadamard_transform_symmetric_quantized_s16 for
+// details.
+void fast_hadamard_transform_symmetric_quantized_s16_28N(
+    int16_t* vec,
+    int log2_vec_size);
+
 } // namespace executorch
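Editor's note (not part of the patch): the odd-log2_vec_size path in quantized_normalize_after_fht divides by sqrt(2) using the integer ratio 408/577. The short standalone program below, a sketch using only the C++ standard library, checks how close that ratio is to 1/sqrt(2); the file name and output format are illustrative.

// check_inv_sqrt2.cpp (illustrative): verify that 408/577 matches
// 1/sqrt(2) to roughly 1e-06, as the comment in
// quantized_normalize_after_fht claims.
#include <cmath>
#include <cstdio>

int main() {
  const double approx = 408.0 / 577.0;        // rational approximation
  const double exact = 1.0 / std::sqrt(2.0);  // true value
  std::printf("408/577    = %.9f\n", approx);
  std::printf("1/sqrt(2)  = %.9f\n", exact);
  std::printf("difference = %.3e\n", approx - exact);  // about -1.06e-06
  return 0;
}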
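A minimal usage sketch for the new fast_hadamard_transform_symmetric_quantized_s16_28N entry point, assuming the include path below resolves against the repository root and that the caller supplies 28 contiguous blocks of (1 << log2_vec_size) int16_t values; the block size and fill value here are placeholders.

#include <cstdint>
#include <vector>

#include "extension/llm/custom_ops/spinquant/fast_hadamard_transform.h"

int main() {
  const int log2_vec_size = 6;                  // each of the 28 blocks holds 64 values
  const int total = 28 * (1 << log2_vec_size);  // full 28N buffer
  std::vector<int16_t> vec(total, 1);           // placeholder quantized data
  // Transforms the whole 28N buffer in place, including the final
  // clamp back to the symmetric int16 range.
  executorch::fast_hadamard_transform_symmetric_quantized_s16_28N(
      vec.data(), log2_vec_size);
  return 0;
}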