93 changes: 64 additions & 29 deletions extension/llm/custom_ops/spinquant/fast_hadamard_transform.h
@@ -41,6 +41,41 @@ void normalize_after_fht(T* out, int log2_vec_size) {
}
}

// Normalization step: divide by sqrt(1 << log2_vec_size). Similar
// to fast_sqrt above, if N is even, then the maximum-precision way
// to do this is right-shift by log2_vec_size / 2. If N is odd, we
// still do the right-shift, and then we have an extra division by
// sqrt(2) that we perform by making use of a sufficiently accurate
// rational approximation. Our initial idea was to divide by sqrt(2)
// by adjusting the quantization scale, but that would cause this
// function to tend to increase the magnitude of the elements of
// vec, which would result in clipping and therefore accuracy
// loss, especially compounded over 30+ transformer layers.
void quantized_normalize_after_fht(
const int32_t* tmp,
int16_t* out,
int log2_vec_size,
int vec_size) {
const int log2_sqrt_vec_size = log2_vec_size / 2;
constexpr int32_t qmin = -(1 << 15) + 1;
constexpr int32_t qmax = -qmin;
if (log2_vec_size % 2 != 0) {
// 408 / 577 - 1.0 / sqrt(2) ~= -1.06e-06, which should be close enough.
static const int32_t inv_sqrt_2_numerator = 408;
static const int32_t inv_sqrt_2_denominator = 577;
for (int ii = 0; ii < vec_size; ++ii) {
const auto val_over_sqrt_vec_size =
(tmp[ii] * inv_sqrt_2_numerator / inv_sqrt_2_denominator) >>
log2_sqrt_vec_size;
out[ii] = std::clamp(val_over_sqrt_vec_size, qmin, qmax);
}
} else {
for (int ii = 0; ii < vec_size; ++ii) {
out[ii] = std::clamp(tmp[ii] >> log2_sqrt_vec_size, qmin, qmax);
}
}
}
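
// Worked example (editor's illustration, not part of the original header):
// with log2_vec_size = 3, vec_size = 8 and the target divisor is
// sqrt(8) = 2 * sqrt(2). Take tmp[ii] = 4616 (= 8 * 577, chosen so the
// integer math below is exact):
//   4616 * 408 / 577 = 3264   (divide by sqrt(2) via the 408/577 approximation)
//   3264 >> 1        = 1632   (right-shift by log2_sqrt_vec_size = 1)
// which matches the exact value 4616 / sqrt(8) ~= 1632.0 before the final
// clamp to the symmetric int16 range [-32767, 32767].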

template <typename T>
void fast_hadamard_transform_unnormalized_simple_impl(
T* vec,
@@ -115,35 +150,8 @@ void fast_hadamard_transform_symmetric_quantized_s16(
internal::fast_hadamard_transform_unnormalized_simple_impl(
tmp.get(), log2_vec_size);

// Normalization step: divide by sqrt(1 << log2_vec_size). Similar
// to fast_sqrt above, if N is even, then the maximum-precision way
// to do this is right-shift by log2_vec_size / 2. If N is odd, we
// still do the right-shift, and then we have an extra division by
// sqrt(2) that we perform by making use of a sufficiently accurate
// rational approximation. Our initial idea was to divide by sqrt(2)
// by adjusting the quantization scale, but that would cause this
// function to tend to increase the magnitude of the elements of
// vec, which would resulting in clipping and therefore accuracy
// loss, especially compounded over 30+ transformer layers.
const int log2_sqrt_vec_size = log2_vec_size / 2;
constexpr int32_t qmin = -(1 << 15) + 1;
constexpr int32_t qmax = -qmin;
if (log2_vec_size % 2 != 0) {
// 408 / 577 - 1.0 / sqrt(2) ~= 1.062e-0.6, which should be close enough.
static const int32_t inv_sqrt_2_numerator = 408;
static const int32_t inv_sqrt_2_denominator = 577;
for (int ii = 0; ii < vec_size; ++ii) {
const auto val_over_sqrt_vec_size =
(tmp[ii] * inv_sqrt_2_numerator / inv_sqrt_2_denominator) >>
log2_sqrt_vec_size;
vec[ii] = std::clamp(val_over_sqrt_vec_size, qmin, qmax);
}
} else {
for (int ii = 0; ii < vec_size; ++ii) {
vec[ii] = std::clamp(tmp[ii] >> log2_sqrt_vec_size, qmin, qmax);
}
}
return;
internal::quantized_normalize_after_fht(
tmp.get(), vec, log2_vec_size, vec_size);
}

// Like fast_hadamard_transform, but vec must be of length 28 * (1 <<
@@ -163,4 +171,31 @@ void fast_hadamard_transform_28N(T* vec, int log2_vec_size) {
}
}

// We don't need the quantization scale; see the function-level
// comment on fast_hadamard_transform_symmetric_quantized_s16 for
// details.
void fast_hadamard_transform_symmetric_quantized_s16_28N(
int16_t* vec,
int log2_vec_size) {
if (log2_vec_size == 0) {
return;
}
const int vec_size = (1 << log2_vec_size);

auto tmp = std::make_unique<int32_t[]>(vec_size * 28);
std::copy(vec, vec + vec_size * 28, tmp.get());

for (int ii = 0; ii < 28; ++ii) {
internal::fast_hadamard_transform_unnormalized_simple_impl(
&tmp[ii * vec_size], log2_vec_size);
}

for (int ii = 0; ii < vec_size; ++ii) {
hadamard_mult_28_strided(&tmp[ii], vec_size);
}

internal::quantized_normalize_after_fht(
tmp.get(), vec, log2_vec_size, vec_size * 28);
}
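
// Usage sketch (editor's illustration; the buffer name and contents are
// assumptions, not part of this header): with log2_vec_size = 2 the caller
// supplies 28 * (1 << 2) = 112 contiguous int16_t values, which are
// transformed in place.
//
//   std::array<int16_t, 112> activations{};  // filled with quantized values
//   fast_hadamard_transform_symmetric_quantized_s16_28N(
//       activations.data(), /*log2_vec_size=*/2);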

} // namespace executorch