diff --git a/.lintrunner.toml b/.lintrunner.toml index c28512c5986..eca965bb1e6 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -74,6 +74,8 @@ exclude_patterns = [ # NB: Objective-C is not supported 'examples/apple/**', 'examples/demo-apps/apple_ios/**', + # File contains @generated + 'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h', ] command = [ 'python', @@ -177,6 +179,8 @@ exclude_patterns = [ '**/*.bat', '**/*.jpg', '**/*.jar', + # File contains @generated + 'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h', ] command = [ 'python', diff --git a/extension/llm/custom_ops/spinquant/README.md b/extension/llm/custom_ops/spinquant/README.md new file mode 100644 index 00000000000..e946e0ee60e --- /dev/null +++ b/extension/llm/custom_ops/spinquant/README.md @@ -0,0 +1,16 @@ +# SpinQuant + +This is an implementation of the [Fast Hadamard +Transform](https://en.wikipedia.org/wiki/Fast_Walsh–Hadamard_transform) +as used in [SpinQuant](https://arxiv.org/abs/2405.16406) (for the R3 +and R4 matrices), [QuaRot](https://arxiv.org/abs/2404.00456), and +[QuIP#](https://arxiv.org/pdf/2402.04396). We follow those papers' +method (as implemented in +https://github.com/Dao-AILab/fast-hadamard-transform/) for extending +the transform to non-power-of-two input sizes. No CUDA kernel is provided +because https://github.com/Dao-AILab/fast-hadamard-transform/ already +offers one. + +The intended long-term destination for this code is pytorch/ao; it +lives in ExecuTorch temporarily until ExecuTorch can take a C++ +dependency on torchao. 
diff --git a/extension/llm/custom_ops/spinquant/TARGETS b/extension/llm/custom_ops/spinquant/TARGETS
new file mode 100644
index 00000000000..0a42614a385
--- /dev/null
+++ b/extension/llm/custom_ops/spinquant/TARGETS
@@ -0,0 +1,5 @@
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
diff --git a/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h b/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h
new file mode 100644
index 00000000000..d481567b8be
--- /dev/null
+++ b/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// (c) Meta Platforms, Inc. and affiliates.
+#pragma once
+
+#include <cassert>
+#include <cmath>
+#include <limits>
+
+#include "fast_hadamard_transform_special.h"
+
+namespace executorch {
+namespace internal {
+
+// Returns the square root of (1 << log2_n), converted to T.
+// log2_n must be non-negative and small enough that 1 << (log2_n / 2)
+// fits in an int.
+template <typename T>
+T fast_sqrt_of_power_of_2(int log2_n) {
+  assert(log2_n >= 0);
+  assert(log2_n / 2 < std::numeric_limits<int>::digits);
+  // The square root of 2**N is, by definition, 2**(N/2), which is
+  // trivial to compute for even N using a left shift.
+  //
+  // For odd N, 2**(N/2) = 2**(floor(N/2) + 1/2)
+  //                     = 2**(floor(N/2)) * (2 ** (1/2))
+  //                     = 2**(floor(N/2)) * sqrt(2)
+  // which is again fast to compute.
+  const T even_part = T(1 << (log2_n / 2));
+  return (log2_n % 2) ? even_part * T(std::sqrt(2.0)) : even_part;
+}
+
+// Multiplies each of the (1 << log2_vec_size) elements of out by
+// 1 / sqrt(1 << log2_vec_size), the scale factor applied after the
+// unnormalized add/subtract passes of the transform.
+template <typename T>
+void normalize_after_fht(T* out, int log2_vec_size) {
+  assert(log2_vec_size >= 0);
+  assert(log2_vec_size < std::numeric_limits<int>::digits);
+  const T inv_sqrt = T(1) / fast_sqrt_of_power_of_2<T>(log2_vec_size);
+  const int vec_size = 1 << log2_vec_size;
+  for (int ii = 0; ii < vec_size; ++ii) {
+    out[ii] *= inv_sqrt;
+  }
+}
+
+// In-place iterative fast Walsh-Hadamard transform of the
+// (1 << log2_vec_size) elements of vec, followed by normalization
+// (see normalize_after_fht).
+template <typename T>
+void fast_hadamard_transform_simple_impl(T* vec, int log2_vec_size) {
+  assert(log2_vec_size >= 0);
+  assert(log2_vec_size < std::numeric_limits<int>::digits);
+  if (log2_vec_size == 0) {
+    // A single element is its own transform and the scale factor is 1.
+    return;
+  }
+
+  const int vec_size = 1 << log2_vec_size;
+  // Each pass combines adjacent blocks of `step` elements into blocks
+  // of 2 * step elements; log2_vec_size passes cover the whole vector.
+  for (int step = 1; step < vec_size; step *= 2) {
+    for (int ii = 0; ii < vec_size; ii += step * 2) {
+      for (int jj = ii; jj < ii + step; ++jj) {
+        const T x = vec[jj];
+        const T y = vec[jj + step];
+        vec[jj] = x + y;
+        vec[jj + step] = x - y;
+      }
+    }
+  }
+
+  normalize_after_fht(vec, log2_vec_size);
+}
+
+} // namespace internal
+
+// Compute the fast Walsh-Hadamard transform
+// (https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform)
+// of vec, which must be of length (1 << log2_vec_size). The transform
+// is computed in place and the result is scaled by
+// 1 / sqrt(1 << log2_vec_size).
+template <typename T>
+void fast_hadamard_transform(T* vec, int log2_vec_size) {
+  internal::fast_hadamard_transform_simple_impl(vec, log2_vec_size);
+}
+
+// Like fast_hadamard_transform, but vec must be of length 28 * (1 <<
+// log2_vec_size) and the transform is computed by interpreting vec as
+// a (28, 1 << log2_vec_size) matrix and performing 28 FHTs, followed
+// by (1 << log2_vec_size) multiplications by a particular Hadamard
+// matrix of size 28x28 (see special_hadamard_code_gen.py for the
+// exact matrix).
+template +void fast_hadamard_transform_28N(T* vec, int log2_vec_size) { + const int vec_size = (1 << log2_vec_size); + for (int ii = 0; ii < 28; ++ii) { + fast_hadamard_transform(&vec[ii * vec_size], log2_vec_size); + } + for (int ii = 0; ii < vec_size; ++ii) { + hadamard_mult_28_strided(&vec[ii], vec_size); + } +} + +} // namespace executorch diff --git a/extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h b/extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h new file mode 100644 index 00000000000..ca5a8d61e73 --- /dev/null +++ b/extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h @@ -0,0 +1,241 @@ +// @generated by special_hadamard_code_gen.py strided_cpu + + +#pragma once + + +template +void hadamard_mult_12_strided(T* input, int stride) { + T x[12]; + T out[12]; + x[0] = input[0 * stride]; + x[1] = input[1 * stride]; + x[2] = input[2 * stride]; + x[3] = input[3 * stride]; + x[4] = input[4 * stride]; + x[5] = input[5 * stride]; + x[6] = input[6 * stride]; + x[7] = input[7 * stride]; + x[8] = input[8 * stride]; + x[9] = input[9 * stride]; + x[10] = input[10 * stride]; + x[11] = input[11 * stride]; + out[0] = + x[0] - x[1] - x[2] - x[3] - x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11]; + out[1] = + x[0] + x[1] - x[2] + x[3] - x[4] - x[5] - x[6] + x[7] + x[8] + x[9] - x[10] + x[11]; + out[2] = + x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] + x[9] + x[10] - x[11]; + out[3] = + x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + x[9] + x[10] + x[11]; + out[4] = + x[0] + x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8] - x[9] + x[10] + x[11]; + out[5] = + x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11]; + out[6] = + x[0] + x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11]; + out[7] = + x[0] - x[1] + x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] + x[9] - x[10] - x[11]; + out[8] = + x[0] - x[1] - x[2] + x[3] + x[4] 
+ x[5] - x[6] + x[7] + x[8] - x[9] + x[10] - x[11]; + out[9] = + x[0] - x[1] - x[2] - x[3] + x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11]; + out[10] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] + x[7] - x[8] + x[9] + x[10] - x[11]; + out[11] = + x[0] - x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] + x[8] - x[9] + x[10] + x[11]; + #pragma unroll + for (int ii = 0; ii < 12; ++ii) { input[stride * ii] = out[ii]; } +} + + +template +void hadamard_mult_20_strided(T* input, int stride) { + T x[20]; + T out[20]; + x[0] = input[0 * stride]; + x[1] = input[1 * stride]; + x[2] = input[2 * stride]; + x[3] = input[3 * stride]; + x[4] = input[4 * stride]; + x[5] = input[5 * stride]; + x[6] = input[6 * stride]; + x[7] = input[7 * stride]; + x[8] = input[8 * stride]; + x[9] = input[9 * stride]; + x[10] = input[10 * stride]; + x[11] = input[11 * stride]; + x[12] = input[12 * stride]; + x[13] = input[13 * stride]; + x[14] = input[14 * stride]; + x[15] = input[15 * stride]; + x[16] = input[16 * stride]; + x[17] = input[17 * stride]; + x[18] = input[18 * stride]; + x[19] = input[19 * stride]; + out[0] = + x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11] - x[12] - x[13] + x[14] + x[15] - x[16] + x[17] + x[18] - x[19]; + out[1] = - x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] - x[7] - x[8] - x[9] + x[10] + x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - x[17] + x[18] + x[19]; + out[2] = - x[0] - x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11] + x[12] + x[13] - x[14] + x[15] - x[16] + x[17] - x[18] + x[19]; + out[3] = - x[0] - x[1] - x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] - x[9] - x[10] - x[11] + x[12] + x[13] + x[14] + x[15] + x[16] - x[17] + x[18] - x[19]; + out[4] = - x[0] - x[1] - x[2] - x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] + x[13] + x[14] - x[15] + x[16] + x[17] - x[18] + x[19]; + out[5] = - x[0] + x[1] + x[2] + x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] 
- x[10] + x[11] - x[12] - x[13] + x[14] + x[15] + x[16] - x[17] - x[18] + x[19]; + out[6] = + x[0] - x[1] + x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] - x[13] - x[14] + x[15] + x[16] + x[17] - x[18] - x[19]; + out[7] = + x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] - x[14] - x[15] + x[16] + x[17] + x[18] - x[19]; + out[8] = + x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + x[14] - x[15] - x[16] + x[17] + x[18] + x[19]; + out[9] = + x[0] + x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] + x[13] - x[14] + x[15] - x[16] - x[17] + x[18] + x[19]; + out[10] = - x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] + x[10] - x[11] - x[12] - x[13] - x[14] - x[15] + x[16] + x[17] + x[18] + x[19]; + out[11] = - x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + x[17] + x[18] + x[19]; + out[12] = + x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + x[9] - x[10] - x[11] + x[12] - x[13] - x[14] + x[15] + x[16] - x[17] + x[18] + x[19]; + out[13] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] - x[14] + x[15] + x[16] + x[17] - x[18] + x[19]; + out[14] = - x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] - x[10] - x[11] - x[12] - x[13] + x[14] + x[15] + x[16] + x[17] + x[18] - x[19]; + out[15] = - x[0] + x[1] - x[2] - x[3] + x[4] - x[5] - x[6] + x[7] + x[8] - x[9] + x[10] - x[11] - x[12] - x[13] - x[14] + x[15] - x[16] - x[17] - x[18] - x[19]; + out[16] = + x[0] - x[1] + x[2] - x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] - x[13] - x[14] - x[15] + x[16] - x[17] - x[18] - x[19]; + out[17] = - x[0] + x[1] - x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + x[9] - x[10] - x[11] + x[12] - x[13] - x[14] - x[15] 
- x[16] + x[17] - x[18] - x[19]; + out[18] = - x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] + x[13] - x[14] - x[15] - x[16] - x[17] + x[18] - x[19]; + out[19] = + x[0] - x[1] - x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] + x[14] - x[15] - x[16] - x[17] - x[18] + x[19]; + #pragma unroll + for (int ii = 0; ii < 20; ++ii) { input[stride * ii] = out[ii]; } +} + + +template +void hadamard_mult_28_strided(T* input, int stride) { + T x[28]; + T out[28]; + x[0] = input[0 * stride]; + x[1] = input[1 * stride]; + x[2] = input[2 * stride]; + x[3] = input[3 * stride]; + x[4] = input[4 * stride]; + x[5] = input[5 * stride]; + x[6] = input[6 * stride]; + x[7] = input[7 * stride]; + x[8] = input[8 * stride]; + x[9] = input[9 * stride]; + x[10] = input[10 * stride]; + x[11] = input[11 * stride]; + x[12] = input[12 * stride]; + x[13] = input[13 * stride]; + x[14] = input[14 * stride]; + x[15] = input[15 * stride]; + x[16] = input[16 * stride]; + x[17] = input[17 * stride]; + x[18] = input[18 * stride]; + x[19] = input[19 * stride]; + x[20] = input[20 * stride]; + x[21] = input[21 * stride]; + x[22] = input[22 * stride]; + x[23] = input[23 * stride]; + x[24] = input[24 * stride]; + x[25] = input[25 * stride]; + x[26] = input[26 * stride]; + x[27] = input[27 * stride]; + out[0] = + x[0] - x[1] - x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] + x[14] - x[15] + x[16] - x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] + x[24] + x[25] - x[26] - x[27]; + out[1] = - x[0] + x[1] - x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] + x[9] - x[10] - x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + x[17] - x[18] - x[19] + x[20] - x[21] + x[22] - x[23] - x[24] + x[25] + x[26] - x[27]; + out[2] = - x[0] - x[1] + x[2] - x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] - x[17] + x[18] - x[19] - x[20] - x[21] - x[22] + 
x[23] - x[24] - x[25] + x[26] + x[27]; + out[3] = - x[0] - x[1] - x[2] + x[3] - x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] + x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] + x[24] - x[25] - x[26] + x[27]; + out[4] = - x[0] - x[1] - x[2] - x[3] + x[4] - x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - x[17] + x[18] - x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + x[25] - x[26] - x[27]; + out[5] = - x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] - x[7] - x[8] - x[9] - x[10] + x[11] + x[12] + x[13] + x[14] - x[15] - x[16] + x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] - x[24] - x[25] + x[26] - x[27]; + out[6] = - x[0] - x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] + x[12] + x[13] - x[14] + x[15] - x[16] - x[17] + x[18] - x[19] + x[20] - x[21] - x[22] + x[23] + x[24] - x[25] - x[26] + x[27]; + out[7] = - x[0] - x[1] + x[2] + x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - x[17] - x[18] + x[19] + x[20] + x[21] - x[22] + x[23] - x[24] - x[25] + x[26] - x[27]; + out[8] = - x[0] - x[1] - x[2] + x[3] + x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + x[17] - x[18] - x[19] + x[20] - x[21] + x[22] - x[23] + x[24] - x[25] - x[26] + x[27]; + out[9] = + x[0] - x[1] - x[2] - x[3] + x[4] + x[5] + x[6] - x[7] - x[8] + x[9] - x[10] - x[11] - x[12] - x[13] + x[14] + x[15] - x[16] + x[17] + x[18] - x[19] - x[20] + x[21] - x[22] + x[23] - x[24] + x[25] - x[26] - x[27]; + out[10] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - x[17] + x[18] + x[19] - x[20] - x[21] + x[22] - x[23] + x[24] - x[25] + x[26] - x[27]; + out[11] = + x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] - x[13] - x[14] - x[15] + x[16] + x[17] - 
x[18] + x[19] + x[20] - x[21] - x[22] + x[23] - x[24] + x[25] - x[26] + x[27]; + out[12] = + x[0] + x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + x[14] - x[15] - x[16] + x[17] + x[18] - x[19] + x[20] + x[21] - x[22] - x[23] + x[24] - x[25] + x[26] - x[27]; + out[13] = - x[0] + x[1] + x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] + x[13] + x[14] + x[15] - x[16] - x[17] + x[18] + x[19] - x[20] - x[21] + x[22] - x[23] - x[24] + x[25] - x[26] + x[27]; + out[14] = - x[0] + x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] + x[10] + x[11] - x[12] - x[13] + x[14] - x[15] - x[16] - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] + x[24] + x[25] + x[26] - x[27]; + out[15] = + x[0] - x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] + x[8] - x[9] - x[10] + x[11] + x[12] - x[13] - x[14] + x[15] - x[16] - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] + x[24] + x[25] + x[26] + x[27]; + out[16] = - x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] + x[9] - x[10] - x[11] + x[12] + x[13] - x[14] - x[15] + x[16] - x[17] - x[18] - x[19] - x[20] + x[21] - x[22] - x[23] - x[24] + x[25] + x[26] + x[27]; + out[17] = + x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] + x[10] - x[11] - x[12] + x[13] - x[14] - x[15] - x[16] + x[17] - x[18] - x[19] - x[20] + x[21] + x[22] - x[23] - x[24] - x[25] + x[26] + x[27]; + out[18] = + x[0] + x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] + x[11] - x[12] - x[13] - x[14] - x[15] - x[16] - x[17] + x[18] - x[19] - x[20] + x[21] + x[22] + x[23] - x[24] - x[25] - x[26] + x[27]; + out[19] = - x[0] + x[1] + x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11] + x[12] - x[13] - x[14] - x[15] - x[16] - x[17] - x[18] + x[19] - x[20] + x[21] + x[22] + x[23] + x[24] - x[25] - x[26] - x[27]; + out[20] = + x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - 
x[12] + x[13] - x[14] - x[15] - x[16] - x[17] - x[18] - x[19] + x[20] - x[21] + x[22] + x[23] + x[24] + x[25] - x[26] - x[27]; + out[21] = - x[0] + x[1] + x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - x[9] + x[10] + x[11] - x[12] + x[13] + x[14] + x[15] - x[16] - x[17] - x[18] - x[19] + x[20] + x[21] - x[22] - x[23] - x[24] - x[25] - x[26] - x[27]; + out[22] = + x[0] - x[1] + x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] - x[10] + x[11] + x[12] - x[13] + x[14] + x[15] + x[16] - x[17] - x[18] - x[19] - x[20] - x[21] + x[22] - x[23] - x[24] - x[25] - x[26] - x[27]; + out[23] = + x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] - x[11] + x[12] + x[13] - x[14] + x[15] + x[16] + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - x[25] - x[26] - x[27]; + out[24] = - x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + x[9] - x[10] + x[11] - x[12] + x[13] - x[14] - x[15] + x[16] + x[17] + x[18] - x[19] - x[20] - x[21] - x[22] - x[23] + x[24] - x[25] - x[26] - x[27]; + out[25] = - x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] - x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] - x[16] + x[17] + x[18] + x[19] - x[20] - x[21] - x[22] - x[23] - x[24] + x[25] - x[26] - x[27]; + out[26] = + x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] + x[13] - x[14] - x[15] - x[16] - x[17] + x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] + x[26] - x[27]; + out[27] = + x[0] + x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + x[9] + x[10] - x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - x[17] - x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] - x[26] + x[27]; + #pragma unroll + for (int ii = 0; ii < 28; ++ii) { input[stride * ii] = out[ii]; } +} + + +template +void hadamard_mult_40_strided(T* input, int stride) { + T x[40]; + T out[40]; + x[0] = input[0 * stride]; + x[1] = input[1 * stride]; + x[2] = input[2 * stride]; + x[3] = input[3 * 
stride]; + x[4] = input[4 * stride]; + x[5] = input[5 * stride]; + x[6] = input[6 * stride]; + x[7] = input[7 * stride]; + x[8] = input[8 * stride]; + x[9] = input[9 * stride]; + x[10] = input[10 * stride]; + x[11] = input[11 * stride]; + x[12] = input[12 * stride]; + x[13] = input[13 * stride]; + x[14] = input[14 * stride]; + x[15] = input[15 * stride]; + x[16] = input[16 * stride]; + x[17] = input[17 * stride]; + x[18] = input[18 * stride]; + x[19] = input[19 * stride]; + x[20] = input[20 * stride]; + x[21] = input[21 * stride]; + x[22] = input[22 * stride]; + x[23] = input[23 * stride]; + x[24] = input[24 * stride]; + x[25] = input[25 * stride]; + x[26] = input[26 * stride]; + x[27] = input[27 * stride]; + x[28] = input[28 * stride]; + x[29] = input[29 * stride]; + x[30] = input[30 * stride]; + x[31] = input[31 * stride]; + x[32] = input[32 * stride]; + x[33] = input[33 * stride]; + x[34] = input[34 * stride]; + x[35] = input[35 * stride]; + x[36] = input[36 * stride]; + x[37] = input[37 * stride]; + x[38] = input[38 * stride]; + x[39] = input[39 * stride]; + out[0] = + x[0] - x[1] - x[2] - x[3] - x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] - x[14] - x[15] - x[16] - x[17] - x[18] - x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] - x[26] - x[27] - x[28] - x[29] - x[30] - x[31] - x[32] - x[33] - x[34] - x[35] - x[36] - x[37] - x[38] - x[39]; + out[1] = + x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + x[9] - x[10] + x[11] - x[12] + x[13] + x[14] + x[15] + x[16] - x[17] - x[18] + x[19] + x[20] + x[21] - x[22] + x[23] + x[24] - x[25] - x[26] - x[27] - x[28] + x[29] - x[30] + x[31] - x[32] + x[33] + x[34] + x[35] + x[36] - x[37] - x[38] + x[39]; + out[2] = + x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] - x[13] + x[14] + x[15] + x[16] + x[17] - x[18] - x[19] + x[20] + x[21] + x[22] - x[23] + x[24] + x[25] - x[26] - x[27] - x[28] - x[29] + x[30] - x[31] + x[32] - x[33] + 
x[34] + x[35] + x[36] + x[37] - x[38] - x[39]; + out[3] = + x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] - x[14] + x[15] + x[16] + x[17] + x[18] - x[19] + x[20] - x[21] + x[22] + x[23] - x[24] + x[25] + x[26] - x[27] - x[28] - x[29] - x[30] + x[31] - x[32] + x[33] - x[34] + x[35] + x[36] + x[37] + x[38] - x[39]; + out[4] = + x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + x[14] - x[15] + x[16] + x[17] + x[18] + x[19] + x[20] - x[21] - x[22] + x[23] + x[24] - x[25] + x[26] + x[27] - x[28] - x[29] - x[30] - x[31] + x[32] - x[33] + x[34] - x[35] + x[36] + x[37] + x[38] + x[39]; + out[5] = + x[0] + x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] - x[14] + x[15] - x[16] + x[17] + x[18] + x[19] + x[20] + x[21] - x[22] - x[23] + x[24] + x[25] - x[26] + x[27] + x[28] - x[29] - x[30] - x[31] - x[32] + x[33] - x[34] + x[35] - x[36] + x[37] + x[38] + x[39]; + out[6] = + x[0] + x[1] + x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] - x[17] + x[18] + x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + x[25] + x[26] - x[27] + x[28] + x[29] - x[30] - x[31] - x[32] - x[33] + x[34] - x[35] + x[36] - x[37] + x[38] + x[39]; + out[7] = + x[0] + x[1] + x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] + x[10] - x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + x[17] - x[18] + x[19] + x[20] + x[21] + x[22] + x[23] - x[24] - x[25] + x[26] + x[27] - x[28] + x[29] + x[30] - x[31] - x[32] - x[33] - x[34] + x[35] - x[36] + x[37] - x[38] + x[39]; + out[8] = + x[0] + x[1] + x[2] + x[3] + x[4] - x[5] - x[6] + x[7] + x[8] - x[9] + x[10] + x[11] - x[12] - x[13] - x[14] - x[15] + x[16] - x[17] + x[18] - x[19] + x[20] + x[21] + x[22] + x[23] + x[24] - x[25] - x[26] + x[27] + x[28] - x[29] + x[30] + x[31] - x[32] - x[33] - x[34] - x[35] + x[36] - x[37] + x[38] - x[39]; + out[9] = + x[0] - 
x[1] + x[2] + x[3] + x[4] + x[5] - x[6] - x[7] + x[8] + x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] - x[16] + x[17] - x[18] + x[19] + x[20] - x[21] + x[22] + x[23] + x[24] + x[25] - x[26] - x[27] + x[28] + x[29] - x[30] + x[31] + x[32] - x[33] - x[34] - x[35] - x[36] + x[37] - x[38] + x[39]; + out[10] = + x[0] + x[1] - x[2] + x[3] + x[4] + x[5] + x[6] - x[7] - x[8] + x[9] + x[10] - x[11] + x[12] + x[13] - x[14] - x[15] - x[16] - x[17] + x[18] - x[19] + x[20] + x[21] - x[22] + x[23] + x[24] + x[25] + x[26] - x[27] - x[28] + x[29] + x[30] - x[31] + x[32] + x[33] - x[34] - x[35] - x[36] - x[37] + x[38] - x[39]; + out[11] = + x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] + x[7] - x[8] - x[9] + x[10] + x[11] - x[12] + x[13] + x[14] - x[15] - x[16] - x[17] - x[18] + x[19] + x[20] - x[21] + x[22] - x[23] + x[24] + x[25] + x[26] + x[27] - x[28] - x[29] + x[30] + x[31] - x[32] + x[33] + x[34] - x[35] - x[36] - x[37] - x[38] + x[39]; + out[12] = + x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] - x[9] - x[10] + x[11] + x[12] - x[13] + x[14] + x[15] - x[16] - x[17] - x[18] - x[19] + x[20] + x[21] - x[22] + x[23] - x[24] + x[25] + x[26] + x[27] + x[28] - x[29] - x[30] + x[31] + x[32] - x[33] + x[34] + x[35] - x[36] - x[37] - x[38] - x[39]; + out[13] = + x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] + x[8] + x[9] - x[10] - x[11] + x[12] + x[13] - x[14] + x[15] + x[16] - x[17] - x[18] - x[19] + x[20] - x[21] + x[22] - x[23] + x[24] - x[25] + x[26] + x[27] + x[28] + x[29] - x[30] - x[31] + x[32] + x[33] - x[34] + x[35] + x[36] - x[37] - x[38] - x[39]; + out[14] = + x[0] - x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] + x[9] + x[10] - x[11] - x[12] + x[13] + x[14] - x[15] + x[16] + x[17] - x[18] - x[19] + x[20] - x[21] - x[22] + x[23] - x[24] + x[25] - x[26] + x[27] + x[28] + x[29] + x[30] - x[31] - x[32] + x[33] + x[34] - x[35] + x[36] + x[37] - x[38] - x[39]; + out[15] = + x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] 
+ x[10] + x[11] - x[12] - x[13] + x[14] + x[15] - x[16] + x[17] + x[18] - x[19] + x[20] - x[21] - x[22] - x[23] + x[24] - x[25] + x[26] - x[27] + x[28] + x[29] + x[30] + x[31] - x[32] - x[33] + x[34] + x[35] - x[36] + x[37] + x[38] - x[39]; + out[16] = + x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] - x[8] + x[9] + x[10] + x[11] + x[12] - x[13] - x[14] + x[15] + x[16] - x[17] + x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] + x[25] - x[26] + x[27] - x[28] + x[29] + x[30] + x[31] + x[32] - x[33] - x[34] + x[35] + x[36] - x[37] + x[38] + x[39]; + out[17] = + x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] - x[9] + x[10] + x[11] + x[12] + x[13] - x[14] - x[15] + x[16] + x[17] - x[18] + x[19] + x[20] + x[21] - x[22] - x[23] - x[24] - x[25] + x[26] - x[27] + x[28] - x[29] + x[30] + x[31] + x[32] + x[33] - x[34] - x[35] + x[36] + x[37] - x[38] + x[39]; + out[18] = + x[0] + x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + x[9] - x[10] + x[11] + x[12] + x[13] + x[14] - x[15] - x[16] + x[17] + x[18] - x[19] + x[20] + x[21] + x[22] - x[23] - x[24] - x[25] - x[26] + x[27] - x[28] + x[29] - x[30] + x[31] + x[32] + x[33] + x[34] - x[35] - x[36] + x[37] + x[38] - x[39]; + out[19] = + x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] - x[11] + x[12] + x[13] + x[14] + x[15] - x[16] - x[17] + x[18] + x[19] + x[20] - x[21] + x[22] + x[23] - x[24] - x[25] - x[26] - x[27] + x[28] - x[29] + x[30] - x[31] + x[32] + x[33] + x[34] + x[35] - x[36] - x[37] + x[38] + x[39]; + out[20] = + x[0] - x[1] - x[2] - x[3] - x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] - x[14] - x[15] - x[16] - x[17] - x[18] - x[19] - x[20] + x[21] + x[22] + x[23] + x[24] + x[25] + x[26] + x[27] + x[28] + x[29] + x[30] + x[31] + x[32] + x[33] + x[34] + x[35] + x[36] + x[37] + x[38] + x[39]; + out[21] = + x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + x[9] - x[10] + x[11] - x[12] + x[13] + x[14] + x[15] + x[16] - 
x[17] - x[18] + x[19] - x[20] - x[21] + x[22] - x[23] - x[24] + x[25] + x[26] + x[27] + x[28] - x[29] + x[30] - x[31] + x[32] - x[33] - x[34] - x[35] - x[36] + x[37] + x[38] - x[39]; + out[22] = + x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] - x[13] + x[14] + x[15] + x[16] + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - x[25] + x[26] + x[27] + x[28] + x[29] - x[30] + x[31] - x[32] + x[33] - x[34] - x[35] - x[36] - x[37] + x[38] + x[39]; + out[23] = + x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] - x[14] + x[15] + x[16] + x[17] + x[18] - x[19] - x[20] + x[21] - x[22] - x[23] + x[24] - x[25] - x[26] + x[27] + x[28] + x[29] + x[30] - x[31] + x[32] - x[33] + x[34] - x[35] - x[36] - x[37] - x[38] + x[39]; + out[24] = + x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + x[14] - x[15] + x[16] + x[17] + x[18] + x[19] - x[20] + x[21] + x[22] - x[23] - x[24] + x[25] - x[26] - x[27] + x[28] + x[29] + x[30] + x[31] - x[32] + x[33] - x[34] + x[35] - x[36] - x[37] - x[38] - x[39]; + out[25] = + x[0] + x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] - x[14] + x[15] - x[16] + x[17] + x[18] + x[19] - x[20] - x[21] + x[22] + x[23] - x[24] - x[25] + x[26] - x[27] - x[28] + x[29] + x[30] + x[31] + x[32] - x[33] + x[34] - x[35] + x[36] - x[37] - x[38] - x[39]; + out[26] = + x[0] + x[1] + x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] - x[17] + x[18] + x[19] - x[20] - x[21] - x[22] + x[23] + x[24] - x[25] - x[26] + x[27] - x[28] - x[29] + x[30] + x[31] + x[32] + x[33] - x[34] + x[35] - x[36] + x[37] - x[38] - x[39]; + out[27] = + x[0] + x[1] + x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] + x[10] - x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + x[17] - x[18] + x[19] - x[20] - x[21] - x[22] - x[23] + x[24] 
+ x[25] - x[26] - x[27] + x[28] - x[29] - x[30] + x[31] + x[32] + x[33] + x[34] - x[35] + x[36] - x[37] + x[38] - x[39]; + out[28] = + x[0] + x[1] + x[2] + x[3] + x[4] - x[5] - x[6] + x[7] + x[8] - x[9] + x[10] + x[11] - x[12] - x[13] - x[14] - x[15] + x[16] - x[17] + x[18] - x[19] - x[20] - x[21] - x[22] - x[23] - x[24] + x[25] + x[26] - x[27] - x[28] + x[29] - x[30] - x[31] + x[32] + x[33] + x[34] + x[35] - x[36] + x[37] - x[38] + x[39]; + out[29] = + x[0] - x[1] + x[2] + x[3] + x[4] + x[5] - x[6] - x[7] + x[8] + x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] - x[16] + x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] - x[24] - x[25] + x[26] + x[27] - x[28] - x[29] + x[30] - x[31] - x[32] + x[33] + x[34] + x[35] + x[36] - x[37] + x[38] - x[39]; + out[30] = + x[0] + x[1] - x[2] + x[3] + x[4] + x[5] + x[6] - x[7] - x[8] + x[9] + x[10] - x[11] + x[12] + x[13] - x[14] - x[15] - x[16] - x[17] + x[18] - x[19] - x[20] - x[21] + x[22] - x[23] - x[24] - x[25] - x[26] + x[27] + x[28] - x[29] - x[30] + x[31] - x[32] - x[33] + x[34] + x[35] + x[36] + x[37] - x[38] + x[39]; + out[31] = + x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] + x[7] - x[8] - x[9] + x[10] + x[11] - x[12] + x[13] + x[14] - x[15] - x[16] - x[17] - x[18] + x[19] - x[20] + x[21] - x[22] + x[23] - x[24] - x[25] - x[26] - x[27] + x[28] + x[29] - x[30] - x[31] + x[32] - x[33] - x[34] + x[35] + x[36] + x[37] + x[38] - x[39]; + out[32] = + x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] - x[9] - x[10] + x[11] + x[12] - x[13] + x[14] + x[15] - x[16] - x[17] - x[18] - x[19] - x[20] - x[21] + x[22] - x[23] + x[24] - x[25] - x[26] - x[27] - x[28] + x[29] + x[30] - x[31] - x[32] + x[33] - x[34] - x[35] + x[36] + x[37] + x[38] + x[39]; + out[33] = + x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] + x[8] + x[9] - x[10] - x[11] + x[12] + x[13] - x[14] + x[15] + x[16] - x[17] - x[18] - x[19] - x[20] + x[21] - x[22] + x[23] - x[24] + x[25] - x[26] - x[27] - x[28] - x[29] + x[30] + x[31] - 
x[32] - x[33] + x[34] - x[35] - x[36] + x[37] + x[38] + x[39]; + out[34] = + x[0] - x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] + x[9] + x[10] - x[11] - x[12] + x[13] + x[14] - x[15] + x[16] + x[17] - x[18] - x[19] - x[20] + x[21] + x[22] - x[23] + x[24] - x[25] + x[26] - x[27] - x[28] - x[29] - x[30] + x[31] + x[32] - x[33] - x[34] + x[35] - x[36] - x[37] + x[38] + x[39]; + out[35] = + x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] + x[10] + x[11] - x[12] - x[13] + x[14] + x[15] - x[16] + x[17] + x[18] - x[19] - x[20] + x[21] + x[22] + x[23] - x[24] + x[25] - x[26] + x[27] - x[28] - x[29] - x[30] - x[31] + x[32] + x[33] - x[34] - x[35] + x[36] - x[37] - x[38] + x[39]; + out[36] = + x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] - x[8] + x[9] + x[10] + x[11] + x[12] - x[13] - x[14] + x[15] + x[16] - x[17] + x[18] + x[19] - x[20] + x[21] + x[22] + x[23] + x[24] - x[25] + x[26] - x[27] + x[28] - x[29] - x[30] - x[31] - x[32] + x[33] + x[34] - x[35] - x[36] + x[37] - x[38] - x[39]; + out[37] = + x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] - x[9] + x[10] + x[11] + x[12] + x[13] - x[14] - x[15] + x[16] + x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] + x[24] + x[25] - x[26] + x[27] - x[28] + x[29] - x[30] - x[31] - x[32] - x[33] + x[34] + x[35] - x[36] - x[37] + x[38] - x[39]; + out[38] = + x[0] + x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + x[9] - x[10] + x[11] + x[12] + x[13] + x[14] - x[15] - x[16] + x[17] + x[18] - x[19] - x[20] - x[21] - x[22] + x[23] + x[24] + x[25] + x[26] - x[27] + x[28] - x[29] + x[30] - x[31] - x[32] - x[33] - x[34] + x[35] + x[36] - x[37] - x[38] + x[39]; + out[39] = + x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] - x[11] + x[12] + x[13] + x[14] + x[15] - x[16] - x[17] + x[18] + x[19] - x[20] + x[21] - x[22] - x[23] + x[24] + x[25] + x[26] + x[27] - x[28] + x[29] - x[30] + x[31] - x[32] - x[33] - x[34] - x[35] + x[36] + x[37] - x[38] - 
x[39]; + #pragma unroll + for (int ii = 0; ii < 40; ++ii) { input[stride * ii] = out[ii]; } +} + diff --git a/extension/llm/custom_ops/spinquant/special_hadamard_code_gen.py b/extension/llm/custom_ops/spinquant/special_hadamard_code_gen.py new file mode 100644 index 00000000000..a8b9feb0785 --- /dev/null +++ b/extension/llm/custom_ops/spinquant/special_hadamard_code_gen.py @@ -0,0 +1,279 @@ +# Portions (c) Meta Platforms, Inc. and affiliates. +# This file is adapted from +# https://github.com/Dao-AILab/fast-hadamard-transform/blob/master/csrc/code_gen.py . + +# BSD 3-Clause License + +# Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from pathlib import Path + +import numpy as np + +# From https://en.wikipedia.org/wiki/Paley_construction (construction II for q = 5) + +had_12_paley = """ ++-++++++++++ +--+-+-+-+-+- ++++-++----++ ++---+--+-++- ++++++-++---- ++-+---+--+-+ +++--+++-++-- ++--++---+--+ +++----+++-++ ++--+-++---+- +++++----+++- ++-+--+-++--- +""" + +# From http://neilsloane.com/hadamard/ + +had_12 = """ ++----------- +++-+---+++-+ ++++-+---+++- ++-++-+---+++ +++-++-+---++ ++++-++-+---+ +++++-++-+--- ++-+++-++-+-- ++--+++-++-+- ++---+++-++-+ +++---+++-++- ++-+---+++-++ +""" + +had_20_will = """ ++----+----++--++-++- +-+----+---+++---+-++ +--+----+---+++-+-+-+ +---+----+---+++++-+- +----+----++--++-++-+ +-+++++-----+--+++--+ ++-+++-+---+-+--+++-- +++-++--+---+-+--+++- ++++-+---+---+-+--+++ +++++-----++--+-+--++ +--++-+-++-+-----++++ +---++-+-++-+---+-+++ ++---++-+-+--+--++-++ +++---++-+----+-+++-+ +-++---++-+----+++++- +-+--+--++-+----+---- ++-+-----++-+----+--- +-+-+-+---+--+----+-- +--+-+++------+----+- ++--+--++------+----+ +""" + + +had_28_will = """ ++------++----++-+--+-+--++-- +-+-----+++-----+-+--+-+--++- +--+-----+++---+-+-+----+--++ +---+-----+++---+-+-+-+--+--+ +----+-----+++---+-+-+++--+-- +-----+-----++++--+-+--++--+- +------++----++-+--+-+--++--+ +--++++-+-------++--+++-+--+- +---++++-+-----+-++--+-+-+--+ ++---+++--+----++-++--+-+-+-- +++---++---+----++-++--+-+-+- ++++---+----+----++-++--+-+-+ +++++--------+-+--++-++--+-+- 
+-++++--------+++--++--+--+-+ +-+-++-++--++--+--------++++- ++-+-++--+--++--+--------++++ +-+-+-++--+--++--+----+---+++ ++-+-+-++--+--+---+---++---++ +++-+-+-++--+------+--+++---+ +-++-+-+-++--+------+-++++--- ++-++-+---++--+------+-++++-- +-++--++-+-++-+++----++------ ++-++--++-+-++-+++-----+----- +++-++---+-+-++-+++-----+---- +-++-++-+-+-+-+--+++-----+--- +--++-++++-+-+----+++-----+-- ++--++-+-++-+-+----+++-----+- +++--++-+-++-+-+----++------+ +""" + + +had_40_tpal = """ ++-------------------+------------------- +++-++----+-+-++++--+++-++----+-+-++++--+ ++++-++----+-+-++++--+++-++----+-+-++++-- ++-++-++----+-+-++++-+-++-++----+-+-++++- ++--++-++----+-+-+++++--++-++----+-+-++++ +++--++-++----+-+-+++++--++-++----+-+-+++ ++++--++-++----+-+-+++++--++-++----+-+-++ +++++--++-++----+-+-+++++--++-++----+-+-+ ++++++--++-++----+-+-+++++--++-++----+-+- ++-++++--++-++----+-++-++++--++-++----+-+ +++-++++--++-++----+-++-++++--++-++----+- ++-+-++++--++-++----++-+-++++--++-++----+ +++-+-++++--++-++----++-+-++++--++-++---- ++-+-+-++++--++-++---+-+-+-++++--++-++--- ++--+-+-++++--++-++--+--+-+-++++--++-++-- ++---+-+-++++--++-++-+---+-+-++++--++-++- ++----+-+-++++--++-+++----+-+-++++--++-++ +++----+-+-++++--++-+++----+-+-++++--++-+ ++++----+-+-++++--++-+++----+-+-++++--++- ++-++----+-+-++++--+++-++----+-+-++++--++ ++--------------------+++++++++++++++++++ +++-++----+-+-++++--+--+--++++-+-+----++- ++++-++----+-+-++++-----+--++++-+-+----++ ++-++-++----+-+-++++--+--+--++++-+-+----+ ++--++-++----+-+-++++-++--+--++++-+-+---- +++--++-++----+-+-+++--++--+--++++-+-+--- ++++--++-++----+-+-++---++--+--++++-+-+-- +++++--++-++----+-+-+----++--+--++++-+-+- ++++++--++-++----+-+------++--+--++++-+-+ ++-++++--++-++----+-+-+----++--+--++++-+- +++-++++--++-++----+---+----++--+--++++-+ ++-+-++++--++-++----+-+-+----++--+--++++- +++-+-++++--++-++------+-+----++--+--++++ ++-+-+-++++--++-++----+-+-+----++--+--+++ ++--+-+-++++--++-++---++-+-+----++--+--++ ++---+-+-++++--++-++--+++-+-+----++--+--+ 
def string_to_array(string):
    """Parse a textual sign matrix into a 2-D ``int32`` array.

    Args:
        string: Whitespace-separated rows consisting only of ``+`` and ``-``
            characters. Leading and trailing whitespace is ignored.

    Returns:
        An ``np.int32`` array with one row per whitespace-separated token,
        holding ``1`` for each ``+`` and ``-1`` for each ``-``.

    Raises:
        ValueError: If a row contains a character other than ``+`` or ``-``,
            or (raised by ``np.stack``) if there are no rows or the rows do
            not all have the same length.
    """
    sign_values = {"+": 1, "-": -1}
    rows = []
    for row in string.split():
        # Map characters directly rather than building a numeric string and
        # asking a text parser to cope with a sign detached from its digit.
        try:
            values = [sign_values[ch] for ch in row]
        except KeyError as err:
            raise ValueError(
                f"unexpected character {err.args[0]!r} in Hadamard row {row!r}"
            ) from None
        rows.append(np.array(values, dtype=np.int32))
    return np.stack(rows)


def strided_load_code_gen(N):
    """Return the C++ statements that copy ``N`` strided inputs into ``x``.

    Produces one ``x[i] = input[i * stride];`` statement per index in
    ``range(N)``, separated by a newline followed by a single space.
    """
    return "\n ".join(f"x[{i}] = input[{i} * stride];" for i in range(N))


def array_code_gen(arr, template):
    """Render ``template`` for the square sign matrix ``arr``.

    Args:
        arr: Square 2-D array. An entry equal to ``1`` emits ``+``; any other
            value emits ``-``.
        template: ``str.format`` template. It is formatted with ``N`` (the
            matrix size as a string), ``code`` (one ``out[i] = ...;``
            statement per row) and ``strided_load_code`` (the output of
            ``strided_load_code_gen(N)``).

    Returns:
        The formatted template text.
    """
    N = arr.shape[0]
    assert arr.shape[0] == arr.shape[1], f"expected a square matrix, got {arr.shape}"
    statements = []
    for i, row in enumerate(arr):
        terms = " ".join(
            f"{'+' if value == 1 else '-'} x[{j}]" for j, value in enumerate(row)
        )
        statements.append(f"out[{i}] = {terms};")
    return template.format(
        N=str(N),
        code="\n ".join(statements),
        strided_load_code=strided_load_code_gen(N),
    )


def main(option="cuda"):
    """Generate ``fast_hadamard_transform_special.h`` next to this script.

    Args:
        option: Key into ``OPTION_TO_TEMPLATE`` selecting which template is
            rendered for every matrix in ``had_strings``.

    Raises:
        ValueError: If ``option`` is not a key of ``OPTION_TO_TEMPLATE``.
    """
    try:
        template = OPTION_TO_TEMPLATE[option]
    except KeyError:
        # The KeyError adds nothing beyond the message below, so suppress the
        # chained traceback. ValueError is still an Exception subclass.
        raise ValueError(
            f"bad target option {option}; options are {', '.join(OPTION_TO_TEMPLATE.keys())}"
        ) from None
    output_path = Path(__file__).parent / "fast_hadamard_transform_special.h"
    # The marker word is spliced in via an f-string expression so the literal
    # marker never appears in this script's own text.
    # NOTE(review): presumably to keep tooling from treating this script as
    # generated code -- confirm.
    generated_line = f"// @{'generated'} by special_hadamard_code_gen.py {option}\n"

    body = "".join(array_code_gen(string_to_array(s), template) for s in had_strings)
    output_path.write_text(generated_line + header + body)