diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 1321fb8d99..1507090903 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -493,7 +493,7 @@ MTL::Library* Device::build_library_(const std::string& source_string) { NS::Error* error = nullptr; auto options = MTL::CompileOptions::alloc()->init(); options->setFastMathEnabled(false); - options->setLanguageVersion(get_metal_version()); + // options->setLanguageVersion(get_metal_version()); auto mtl_lib = device_->newLibrary(ns_code, options, &error); options->release(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 5206c9b54a..7db84e244b 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -745,7 +745,8 @@ MTL::ComputePipelineState* get_quantized_kernel( const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::gemm() << metal::quantized() + kernel_source << "#include " + << metal::utils() << metal::gemm() << metal::quantized() << template_def; return kernel_source.str(); }); diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index b2b0d8d8f9..c134d159b4 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -2,23 +2,42 @@ #include #include +#include constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; constant bool align_K [[function_constant(202)]]; using namespace metal; +using namespace mpp::tensor_ops; #define MLX_MTL_CONST static constant constexpr const MLX_MTL_CONST int SIMD_SIZE = 32; MLX_MTL_CONST int QUAD_SIZE = 4; +template struct infer_tg_type; +template<> struct infer_tg_type { using type = half; }; +template<> struct infer_tg_type { using type = float; }; +template<> struct infer_tg_type { using type = float; }; + +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 
5 : 3); +} + template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; @@ -57,6 +76,21 @@ inline U load_vector(const device T* x, thread U* x_thread) { } } + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + else if (bits == 6) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -80,8 +114,9 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; @@ -121,6 +156,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { } } + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + else if (bits == 6) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -153,8 +203,9 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; @@ -199,6 +250,26 @@ inline U qdot( } } + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { x_thread += 4 * i; @@ -234,8 +305,9 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined 
for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; @@ -280,6 +352,26 @@ inline U qdot_safe( } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { x_thread += 4 * i; @@ -310,8 +402,9 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; @@ -348,8 +441,31 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); } + } - } else if (bits == 6) { + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + + else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { uint8_t w0 = w[3 * i]; uint8_t w1 = w[3 * i + 1]; @@ -371,12 +487,13 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { } } -template +template inline void -dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { +dequantize(const device uint8_t* w, U scale, U bias, threadgroup TgType* w_local) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = { @@ -416,11 +533,26 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = (((w[2] & 0xf0) >> 4) + 
((w[3] & 0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } + } + else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { w_local += 4 * i; w += 3 * i; - w_local[0] = (w[0] & 0x3f) * scale + bias; w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; @@ -437,6 +569,7 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { template < typename T, + typename TgType, short BROWS, short BCOLS, short dst_ld, @@ -452,11 +585,12 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; @@ -471,7 +605,7 @@ struct QuantizedBlockLoader { const short bi; const short bj; - threadgroup T* dst; + threadgroup TgType* dst; const device uint8_t* src; const device T* scales; const device T* biases; @@ -481,7 +615,7 @@ struct QuantizedBlockLoader { const device T* scales_, const device T* biases_, const int src_ld_, - threadgroup T* dst_, + threadgroup TgType* dst_, ushort simd_group_id [[simdgroup_index_in_threadgroup]], ushort simd_lane_id [[thread_index_in_simdgroup]]) : src_ld(src_ld_), @@ -507,7 +641,7 @@ struct QuantizedBlockLoader { T scale = *scales; T bias = *biases; for (int i = 0; i < n_reads; i++) { - dequantize( + dequantize( src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); } } @@ -517,16 +651,9 @@ struct QuantizedBlockLoader { return; } - if (reduction_dim == 1 && bi >= src_tile_dim.y) { + if (bi >= src_tile_dim.y) { for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - if (reduction_dim == 0 && bi >= src_tile_dim.x) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); + dst[i] = TgType(0); } return; } @@ -534,7 +661,7 @@ struct QuantizedBlockLoader { T scale = *scales; T bias = *biases; for (int i = 0; i < n_reads; i++) { - dequantize( + dequantize( (device uint8_t*)(src + i * bytes_per_pack), scale, bias, @@ -632,12 +759,11 @@ METAL_FUNC void qmv_fast_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int packs_per_thread = bits == 2 ? 1 : 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 
4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -700,12 +826,12 @@ METAL_FUNC void qmv_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -857,8 +983,9 @@ METAL_FUNC void qvm_impl( uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int tn = 32 / pack_factor; constexpr int block_size = SIMD_SIZE; @@ -953,6 +1080,7 @@ METAL_FUNC void qvm_impl( template < typename T, + typename TgType, const int group_size, const int bits, const bool aligned_N, @@ -965,8 +1093,8 @@ METAL_FUNC void qmm_t_impl( const device T* biases, const device T* x, device T* y, - threadgroup T* Xs, - threadgroup T* Ws, + threadgroup TgType* Xs, + threadgroup TgType* Ws, const constant int& K, const constant int& N, const constant int& M, @@ -981,17 +1109,16 @@ METAL_FUNC void qmm_t_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 
3 : 1; - // Instantiate the appropriate BlockMMA and Loader - using mma_t = mlx::steel:: - BlockMMA; using loader_x_t = - mlx::steel::BlockLoader; + mlx::steel::ConversionBlockLoader; using loader_w_t = QuantizedBlockLoader< T, + TgType, BN, BK, BK_padded, @@ -1008,49 +1135,65 @@ METAL_FUNC void qmm_t_impl( auto wl = (const device uint8_t*)w; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; - // Make the x loader and mma operation - const short num_els = min(BM, M - y_row); - const short num_outs = min(BN, N - y_col); + const short tile_m = min(BM, M - y_row); + const short tile_n = min(BN, N - y_col); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); - mma_t mma_op(simd_gid, simd_lid); - if (num_els < BM) { - if (!aligned_N && num_outs < BN) { + using TgMatrix = tensor, tensor_inline>; + dextents dims_tg_x{BK, BM}; + dextents dims_tg_w{BK, BN}; + dextents dims_y{BN, BM}; + array strides_tg_x{1, BK_padded}; + array strides_tg_w{1, BK_padded}; + array strides_y{1, N}; + TgMatrix tX{Xs, dims_tg_x, strides_tg_x}; + TgMatrix tW{Ws, dims_tg_w, strides_tg_w}; + tensor, tensor_inline> tY{y, dims_y, strides_y}; + + using AccumType = float; + constexpr auto relax_prec = true; + constexpr auto mmul_mode = matmul2d_descriptor::mode::multiply_accumulate; + constexpr auto desc = matmul2d_descriptor(BM, BN, BK, false, true, relax_prec, mmul_mode); + matmul2d> op; + auto rY = op.template get_destination_cooperative_tensor(); + + if (tile_m < BM) { + if (!aligned_N && tile_n < BN) { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_safe(short2(BK, num_outs)); + loader_x.load_safe(short2(BK, tile_m)); + loader_w.load_safe(short2(BK, tile_n)); threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); + op.run(tX, tW, rY); loader_x.next(); loader_w.next(); } } else { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); + loader_x.load_safe(short2(BK, tile_m)); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); + op.run(tX, tW, rY); loader_x.next(); loader_w.next(); } } } else { - if (!aligned_N && num_outs < BN) { + if (!aligned_N && tile_n < BN) { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_unsafe(); - loader_w.load_safe(short2(BK, num_outs)); + loader_w.load_safe(short2(BK, tile_n)); threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); + op.run(tX, tW, rY); loader_x.next(); loader_w.next(); } @@ -1060,8 +1203,7 @@ METAL_FUNC void qmm_t_impl( loader_x.load_unsafe(); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(Xs, Ws); + op.run(tX, tW, rY); loader_x.next(); loader_w.next(); } @@ -1070,15 +1212,18 @@ METAL_FUNC void qmm_t_impl( // Store results to device memory threadgroup_barrier(mem_flags::mem_threadgroup); - if (num_els < BM || num_outs < BN) { - mma_op.store_result_safe(y, N, short2(num_outs, num_els)); - } else { - mma_op.store_result(y, N); + #pragma clang loop unroll(full) + for (int i = 0; i < rY.get_capacity(); ++i) { + auto pos = rY.get_multidimensional_index(i); + if ((pos[1] < tile_m) && (pos[0] < tile_n)) { + tY[pos] = T(rY[i]); + } } } template < typename T, + typename 
TgType, const int group_size, const int bits, const int BM = 32, @@ -1090,8 +1235,8 @@ METAL_FUNC void qmm_n_impl( const device T* biases, const device T* x, device T* y, - threadgroup T* Xs, - threadgroup T* Ws, + threadgroup TgType* Xs, + threadgroup TgType* Ws, const constant int& K, const constant int& N, const constant int& M, @@ -1106,19 +1251,17 @@ METAL_FUNC void qmm_n_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; - // Instantiate the appropriate BlockMMA and Loader - using mma_t = mlx::steel:: - BlockMMA; - using loader_x_t = mlx::steel:: - BlockLoader; + using loader_x_t = + mlx::steel::ConversionBlockLoader; using loader_w_t = QuantizedBlockLoader< T, + TgType, BK, BN, BN_padded, @@ -1132,17 +1275,34 @@ METAL_FUNC void qmm_n_impl( // Set the block const int y_row = tid.y * BM; const int y_col = tid.x * BN; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid); - mma_t mma_op(simd_gid, simd_lid); + + using TgMatrix = tensor, tensor_inline>; + dextents dims_tg_x{BK, BM}; + dextents dims_tg_w{BN, BK}; + dextents dims_y{BN, BM}; + array strides_tg_x{1, BK_padded}; + array strides_tg_w{1, BN_padded}; + array strides_y{1, N}; + TgMatrix tX{Xs, dims_tg_x, strides_tg_x}; + TgMatrix tW{Ws, dims_tg_w, strides_tg_w}; + tensor, tensor_inline> tY{y, dims_y, strides_y}; + + using AccumType = float; + constexpr auto relax_prec = true; + constexpr auto mmul_mode = matmul2d_descriptor::mode::multiply_accumulate; + constexpr auto desc = matmul2d_descriptor(BM, BN, BK, false, false, relax_prec, mmul_mode); + matmul2d> op; + auto rY = op.template get_destination_cooperative_tensor(); if (num_els < BM) { if ((K % BK) != 0) { @@ -1152,7 +1312,7 @@ METAL_FUNC void qmm_n_impl( loader_x.load_safe(short2(BK, num_els)); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); + op.run(tX, tW, rY); loader_x.next(); loader_w.next(); } @@ -1161,14 +1321,14 @@ METAL_FUNC void qmm_n_impl( loader_x.load_safe(short2(num_k, num_els)); loader_w.load_safe(short2(BN, num_k)); threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); + op.run(tX, tW, rY); } else { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_safe(short2(BK, num_els)); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); + op.run(tX, tW, rY); loader_x.next(); loader_w.next(); } @@ -1181,7 +1341,7 @@ METAL_FUNC void qmm_n_impl( loader_x.load_unsafe(); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); + op.run(tX, tW, rY); loader_x.next(); loader_w.next(); } @@ -1190,14 +1350,14 @@ METAL_FUNC void qmm_n_impl( loader_x.load_safe(short2(num_k, BM)); loader_w.load_safe(short2(BN, num_k)); 
threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); + op.run(tX, tW, rY); } else { for (int k = 0; k < K; k += BK) { threadgroup_barrier(mem_flags::mem_threadgroup); loader_x.load_unsafe(); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); + op.run(tX, tW, rY); loader_x.next(); loader_w.next(); } @@ -1206,10 +1366,12 @@ METAL_FUNC void qmm_n_impl( // Store results to device memory threadgroup_barrier(mem_flags::mem_threadgroup); - if (num_els < BM) { - mma_op.store_result_safe(y, N, short2(BN, num_els)); - } else { - mma_op.store_result(y, N); + #pragma clang loop unroll(full) + for (int i = 0; i < rY.get_capacity(); ++i) { + auto pos = rY.get_multidimensional_index(i); + if (pos[1] < num_els) { + tY[pos] = T(rY[i]); + } } } @@ -1576,9 +1738,9 @@ template < const int bits, const bool aligned_N, const bool batched, - const int BM = 32, + const int BM = 64, const int BK = 32, - const int BN = 32> + const int BN = 64> [[kernel]] void qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1604,8 +1766,9 @@ template < constexpr int BK_padded = (BK + 16 / sizeof(T)); - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BN * BK_padded]; + using TgType = typename infer_tg_type::type; + threadgroup TgType Xs[BM * BK_padded]; + threadgroup TgType Ws[BN * BK_padded]; if (batched) { adjust_matrix_offsets( @@ -1625,7 +1788,7 @@ template < b_strides, tid); } - qmm_t_impl( + qmm_t_impl( w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } @@ -1663,8 +1826,9 @@ template < constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BK * BN_padded]; + using TgType = typename infer_tg_type::type; + threadgroup TgType Xs[BM * BK_padded]; + threadgroup TgType Ws[BK * BN_padded]; if (batched) { adjust_matrix_offsets( @@ -1685,7 +1849,7 @@ template < tid); } - qmm_n_impl( + qmm_n_impl( w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } @@ -1914,8 +2078,9 @@ template < constexpr int BK_padded = (BK + 16 / sizeof(T)); - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BN * BK_padded]; + using TgType = typename infer_tg_type::type; + threadgroup TgType Xs[BM * BK_padded]; + threadgroup TgType Ws[BN * BK_padded]; adjust_matrix_offsets( x, @@ -1939,7 +2104,7 @@ template < s_strides, b_strides, tid); - qmm_t_impl( + qmm_t_impl( w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } @@ -1982,8 +2147,9 @@ template < constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BK * BN_padded]; + using TgType = typename infer_tg_type::type; + threadgroup TgType Xs[BM * BK_padded]; + threadgroup TgType Ws[BK * BN_padded]; adjust_matrix_offsets( x, @@ -2007,96 +2173,10 @@ template < s_strides, b_strides, tid); - qmm_n_impl( + qmm_n_impl( w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } -template -METAL_FUNC void gemm_loop_aligned( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const int k_iterations) { - for (int k = 0; k < k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup memory - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and 
accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } -} - -template < - bool rows_aligned, - bool cols_aligned, - bool transpose, - typename T, - typename mma_t, - typename loader_a_t, - typename loader_b_t> -METAL_FUNC void gemm_loop_unaligned( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const int k_iterations, - const short tgp_bm, - const short tgp_bn, - const short tgp_bk) { - for (int k = 0; k < k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup memory - if (rows_aligned) { - loader_a.load_unsafe(); - } else { - loader_a.load_safe(short2(tgp_bk, tgp_bm)); - } - if (cols_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe( - transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } -} - -template -METAL_FUNC void gemm_loop_finalize( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const short2 tile_a, - const short2 tile_b) { - loader_a.load_safe(tile_a); - loader_b.load_safe(tile_b); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); -} - template < typename T, int group_size, @@ -2120,28 +2200,18 @@ template < uint3 tid [[threadgroup_position_in_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; - using mma_t = mlx::steel::BlockMMA< - T, - T, - BM, - BN, - BK, - WM, - WN, - false, - transpose, - BK_padded, - transpose ? BK_padded : BN_padded>; + using TgType = typename infer_tg_type::type; + using loader_x_t = - mlx::steel::BlockLoader; + mlx::steel::ConversionBlockLoader; using loader_w_t = QuantizedBlockLoader< T, + TgType, transpose ? BN : BK, transpose ? BK : BN, transpose ? BK_padded : BN_padded, @@ -2150,8 +2220,8 @@ template < group_size, bits>; - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + threadgroup TgType Xs[BM * BK_padded]; + threadgroup TgType Ws[transpose ? BN * BK_padded : BK * BN_padded]; // Compute the block const int K_w = K * bytes_per_pack / pack_factor; @@ -2184,6 +2254,19 @@ template < scales += transpose ? y_col_long * K_g : y_col / group_size; biases += transpose ? y_col_long * K_g : y_col / group_size; + // Create threadgroup input matrices and device output matrix + using TgMatrix = tensor, tensor_inline>; + dextents dims_tg_x{BK, BM}; + dextents dims_tg_w{(transpose) ? BK : BN, + (transpose) ? BN : BK}; + dextents dims_y{BN, BM}; + array strides_tg_x{1, BK_padded}; + array strides_tg_w{1, (transpose) ? 
BK_padded : BN_padded}; + array strides_y{1, N}; + TgMatrix tX{Xs, dims_tg_x, strides_tg_x}; + TgMatrix tW{Ws, dims_tg_w, strides_tg_w}; + tensor, tensor_inline> tY{y, dims_y, strides_y}; + // Do as many matmuls as necessary uint32_t index; short offset; @@ -2204,8 +2287,13 @@ template < } threadgroup_barrier(mem_flags::mem_none); - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); + // Create matmul operation + using AccumType = float; + constexpr auto relax_prec = true; + constexpr auto mmul_mode = matmul2d_descriptor::mode::multiply_accumulate; + constexpr auto desc = matmul2d_descriptor(BM, BN, BK, false, transpose, relax_prec, mmul_mode); + matmul2d> op; + auto rY = op.template get_destination_cooperative_tensor(); // Prepare threadgroup loading operations thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); @@ -2218,77 +2306,62 @@ template < simd_group_id, simd_lane_id); - // Matrices are all aligned check nothing - if (align_M && align_N) { - gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); - if (!align_K) { + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + for (int k = 0; k < K_it; k++) { threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - - // Store results to device memory - if (offset_next - offset == BM) { - mma_op.store_result(y, N); - } else { - mma_op.store_result_slice( - y, N, short2(0, offset), short2(BN, offset_next)); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + op.run(tX, tW, rY); + loader_x.next(); + loader_w.next(); } - } else { - // Tile aligned so check outside of the hot loop - if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { - gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize( - Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - - // Store results to device memory - if (offset_next - offset == BM) { - mma_op.store_result(y, N); - } else { - mma_op.store_result_slice( - y, N, short2(0, offset), short2(BN, offset_next)); - } + } else if (align_N || tgp_bn == BN) { + for (int k = 0; k < K_it; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, tgp_bm)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + op.run(tX, tW, rY); + loader_x.next(); + loader_w.next(); } - - // Tile partially aligned check rows - else if (align_N || tgp_bn == BN) { - gemm_loop_unaligned( - Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize( - Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - mma_op.store_result_slice( - y, N, short2(0, offset), short2(BN, offset_next)); + } else if (align_M || tgp_bm == BM) { + for (int k = 0; k < K_it; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2((transpose) ? BK : tgp_bn, + (transpose) ? 
tgp_bn : BK)); + threadgroup_barrier(mem_flags::mem_threadgroup); + op.run(tX, tW, rY); + loader_x.next(); + loader_w.next(); } - - // Tile partially aligned check cols - else if (align_M || tgp_bm == BM) { - gemm_loop_unaligned( - Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize( - Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - mma_op.store_result_slice( - y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } else { + for (int k = 0; k < K_it; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, tgp_bm)); + loader_w.load_safe(short2((transpose) ? BK : tgp_bn, + (transpose) ? tgp_bn : BK)); + threadgroup_barrier(mem_flags::mem_threadgroup); + op.run(tX, tW, rY); + loader_x.next(); + loader_w.next(); } + } - // Nothing aligned so check both rows and cols - else { - gemm_loop_unaligned( - Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize( - Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - mma_op.store_result_slice( - y, N, short2(0, offset), short2(tgp_bn, offset_next)); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(tile_x); + loader_w.load_safe(tile_w); + threadgroup_barrier(mem_flags::mem_threadgroup); + op.run(tX, tW, rY); + } + #pragma clang loop unroll(full) + for (int i = 0; i < rY.get_capacity(); ++i) { + auto pos = rY.get_multidimensional_index(i); + if ((pos[1] >= offset) && (pos[1] < offset_next) && (pos[0] < tgp_bn)) { + tY[pos] = T(rY[i]); } } } @@ -2305,13 +2378,13 @@ template constexpr float eps = 1e-7; constexpr int simd_size = 32; constexpr float n_bins = (1 << bits) - 1; - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_reduce = group_size / simd_size; - constexpr int writes_per_reduce = packs_per_int / values_per_reduce; + constexpr int writes_per_reduce = pack_factor / values_per_reduce; constexpr int writes_per_pack = - writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; + writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor; constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; static_assert( group_size % simd_size == 0, @@ -2354,8 +2427,8 @@ template biases[gindex] = static_cast(bias); } - // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t - uint32_t output = 0; + using OutType = metal::conditional_t; + OutType output = 0; #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { @@ -2363,27 +2436,35 @@ template if (bits == 8) { output = val; } else { - output += val << (bits * (i % packs_per_int)); + output |= val << (bits * (i % pack_factor)); } - if (packs_per_int < values_per_reduce && - i % packs_per_int == packs_per_int - 1) { - out[out_index + i / packs_per_int] = output; + if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { + out[out_index + i / pack_factor] = output; output = 0; } else { #pragma clang loop unroll(full) for (int j = 1; j < writes_per_reduce; j++) { uint8_t sval = simd_shuffle_down(val, j); - output += sval << (bits * (j * values_per_reduce + i)); + output |= static_cast(sval) + << (bits * (j * values_per_reduce + i)); } } } if (bits == 3 || bits == 6) { - if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { out[out_index] = output & 0xff; out[out_index + 1] = (output & 0xff00) >> 8; out[out_index + 2] = (output & 0xff0000) >> 16; } + } else if (bits == 5) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + out[out_index + 3] = (output & 0xff000000) >> 24; + out[out_index + 4] = (output & 0xff00000000) >> 32; + } } else { if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { out[out_index / writes_per_reduce] = output; @@ -2399,12 +2480,11 @@ template device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t oindex = offset * packs_per_int; + size_t oindex = offset * pack_factor; size_t gindex = oindex / group_size; T scale = scales[gindex]; T bias = biases[gindex]; @@ -2421,7 +2501,16 @@ template out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; - + } else if (bits == 5) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x1f) * scale + bias; + out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + out[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + out[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + out[7] = ((w[4] & 0xf8) >> 3) * scale + bias; } else if (bits == 6) { w += offset * bytes_per_pack; out[0] = (w[0] & 0x3f) * scale + bias; @@ -2431,7 +2520,7 @@ template } else { uint val = w[offset]; #pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { + for (int i = 0; i < pack_factor; i++) { uint8_t d; if (bits == 2) { d = (val >> (bits * i)) & 0x03; diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 11cd8421bf..4334f8d74e 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -111,8 +111,8 @@ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32) #define instantiate_quantized_all_rhs(type, group_size, bits) \ - instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \ - instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false) + instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 32, 32, 32, 2, 2, true) \ + instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 32, 32, 32, 2, 2, false) #define instantiate_quantized_funcs(type, group_size, bits) \ instantiate_quantized_all_single(type, group_size, bits) \ diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index 2e27ea06f9..980ff40ac2 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -2,6 +2,10 @@ using namespace mlx::steel; +#include + +using namespace mpp::tensor_ops; + /////////////////////////////////////////////////////////////////////////////// // GEMM kernels /////////////////////////////////////////////////////////////////////////////// @@ -462,3 +466,284 @@ template < Otile.template store(O, params->O_strides[2]); } } + +// clang-format off +template < + typename T, + int BQ, + int BK, + int BD, + int WM, + int WN, + typename MaskType = float, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_tops( + device T* Q [[buffer(0)]], + device T* K [[buffer(1)]], + device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant mlx::steel::AttnParams* params [[buffer(4)]], + const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], + const device 
MaskType* mask [[buffer(6), function_constant(has_mask)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + + // Pacifying compiler + (void)lid; + + const int blockQ = tid.x; // Index of the block along Q dimension for this threadgroup + const int q_head = tid.y; + const int kv_head = q_head / params->gqa_factor; + const int batch_id = tid.z; + + const int D = BD; + + const int ROWS_PER_SIMD = BQ / WM; // assumes WN = 1 + + // Offset to this threadgroup's tile + Q += params->Q_strides[0] * batch_id + params->Q_strides[1] * q_head + params->Q_strides[2] * BQ * blockQ; + K += params->K_strides[0] * batch_id + params->K_strides[1] * kv_head; + V += params->V_strides[0] * batch_id + params->V_strides[1] * kv_head; + O += params->O_strides[0] * batch_id + params->O_strides[1] * q_head + params->O_strides[2] * BQ * blockQ; + + Q += params->Q_strides[2] * simd_group_id * ROWS_PER_SIMD; + + int kb_lim = params->NK; + + if (do_causal) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; + kb_lim = min(params->NK, kb_lim); + } + + // Each threadgroup will compute a block of Q * K^T into threadgroup memory + // [ BQ * D ] * [ D * BK ] -> [ BQ * BK ] + threadgroup float tg_QKT[BQ * BK]; + + float softmax_scale = params->scale; + softmax_scale *= 1.44269504089f; // to use exp2 + + using QKVTensorType = tensor, tensor_inline>; + using STensorType = tensor, tensor_inline>; + + const int MUL_SV_BD = D <= 80 ? D : 64; + + // Q * K^T + constexpr auto mul_qkt_op_desc = mpp::tensor_ops::matmul2d_descriptor(/* M = */ ROWS_PER_SIMD, + /* N = */ BK, + /* K = */ D, + /* transposeLeft = */ false, + /* transposeRight = */ true, + /* relaxedPrecision = */ true, + /* mode = */ mpp::tensor_ops::matmul2d_descriptor::mode::multiply); + + mpp::tensor_ops::matmul2d mul_qkt_op; + + // Op which multiplies SoftMax result by V + constexpr auto mul_sv_op_desc = mpp::tensor_ops::matmul2d_descriptor(/* M = */ ROWS_PER_SIMD, + /* N = */ MUL_SV_BD, + /* K = */ BK, + /* transposeLeft = */ false, + /* transposeRight = */ false, + /* relaxedPrecision = */ true, + /* mode = */ mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + mpp::tensor_ops::matmul2d mul_sv_op; + + auto ct_row_sum = mul_qkt_op.template get_row_reduction_destination_cooperative_tensor(); + auto ct_row_max = mul_qkt_op.template get_row_reduction_destination_cooperative_tensor(); + + #pragma clang loop unroll(full) + for (int i = 0; i < ct_row_sum.get_capacity(); i++) { + // ct_row_sum already cleared to zero + ct_row_max[i] = -INFINITY; + } + + // Cooperative output matrix accumulator + auto ctO_0 = mul_sv_op.template get_destination_cooperative_tensor(); + auto ctO_1 = mul_sv_op.template get_destination_cooperative_tensor(); + + STensorType tQKT(tg_QKT + simd_group_id * ROWS_PER_SIMD * BK, dextents(BK, BQ), array({ 1, BK })); + + dextents QExtents = dextents(D, BQ); + + if (!align_Q && blockQ == params->NQ_aligned) + QExtents = dextents(D, params->qL_rem); + + // Outer loop over blocks along K. 
Equivalent to looping over columns of Q * K^T + for (int k = 0; k < kb_lim; k++) { + dextents KVExtents = dextents(D, BK); + + if (!align_K && k == params->NK_aligned) + KVExtents = dextents(D, params->kL_rem); + + QKVTensorType tQ(Q, QExtents, array({ 1, static_cast(params->Q_strides[2]) })); + QKVTensorType tK(K, KVExtents, array({ 1, static_cast(params->K_strides[2]) })); + QKVTensorType tV(V, KVExtents, array({ 1, static_cast(params->V_strides[2]) })); + + auto cTQKT = mul_qkt_op.template get_destination_cooperative_tensor(); + + mul_qkt_op.run(tQ, tK, cTQKT); + + #pragma clang loop unroll(full) + for (int i = 0; i < cTQKT.get_capacity(); i++) + cTQKT[i] *= softmax_scale; + + // Mask out length sequence + if (!align_K && k == (params->NK_aligned)) { + #pragma clang loop unroll(full) + for (int i = 0; i < cTQKT.get_capacity(); i++) { + auto idxs = cTQKT.get_multidimensional_index(i); + + cTQKT[i] = idxs[0] >= params->kL_rem ? -INFINITY : cTQKT[i]; + } + } + + // causal masking + if (do_causal && k >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { + int row_base = blockQ * BQ + params->qL_off + simd_group_id * ROWS_PER_SIMD; + int col_base = k * BK; + + #pragma clang loop unroll(full) + for (int i = 0; i < cTQKT.get_capacity(); i++) { + auto idxs = cTQKT.get_multidimensional_index(i); + + int row = row_base + idxs[1]; + int col = col_base + idxs[0]; + + cTQKT[i] = row < col ? -INFINITY : cTQKT[i]; + } + } + + auto ct_tile_row_max = mul_qkt_op.template get_row_reduction_destination_cooperative_tensor(); + + reduce_rows(cTQKT, ct_tile_row_max, reduction_operation::max, -INFINITY); + + auto ct_row_scale = mul_qkt_op.template get_row_reduction_destination_cooperative_tensor(); + + // Compute SoftMax for each row of this SIMD group + #pragma clang loop unroll(full) + for (int i = 0; i < ct_tile_row_max.get_capacity(); i++) { + float new_max = max(ct_tile_row_max[i], ct_row_max[i]); + + ct_row_scale[i] = fast::exp2(ct_row_max[i] - new_max); + + ct_row_max[i] = new_max; + } + + #pragma clang loop unroll(full) + for (auto it = cTQKT.begin(); it != cTQKT.end(); it++) { + auto row_it = ct_row_max.map_iterator(it); + + *it = fast::exp2(*it - *row_it); + } + + auto ct_tile_row_sum = mul_qkt_op.template get_row_reduction_destination_cooperative_tensor(); + + reduce_rows(cTQKT, ct_tile_row_sum, reduction_operation::sum, 0.0f); + + #pragma clang loop unroll(full) + for (uint i = 0; i < ct_tile_row_sum.get_capacity(); i++) + ct_row_sum[i] = ct_row_sum[i] * ct_row_scale[i] + ct_tile_row_sum[i]; + + #pragma clang loop unroll(full) + for (auto it = ctO_0.begin(); it != ctO_0.end(); it++) { + auto row_it = ct_row_scale.map_iterator(it); + + *it *= *row_it; + } + + if (D == 128) { + #pragma clang loop unroll(full) + for (auto it = ctO_1.begin(); it != ctO_1.end(); it++) { + auto row_it = ct_row_scale.map_iterator(it); + + *it *= *row_it; + } + } + + // Wait for other SIMD groups to finish reading threadgroup memory before we clobber it + simdgroup_barrier(mem_flags::mem_threadgroup); + + cTQKT.store(tQKT); // elided by backend optimization + + // Wait for store to complete + simdgroup_barrier(mem_flags::mem_threadgroup); + + mul_sv_op.run(tQKT, tV, ctO_0); + + if (D == 128) { + threadgroup_barrier(mem_flags::mem_none); + + auto tV_off = tV.slice(MUL_SV_BD, 0); + + mul_sv_op.run(tQKT, tV_off, ctO_1); + } + + K += BK * params->K_strides[2]; + V += BK * params->V_strides[2]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (align_Q || blockQ < params->NQ_aligned) { + // Write out result + #pragma 
clang loop unroll(full) + for (auto it = ctO_0.begin(); it != ctO_0.end(); it++) { + auto md_indices = it.get_multidimensional_index(); + + auto row_it = ct_row_sum.map_iterator(it); + + float softmax_denom = 1.0f / *row_it; + + O[(simd_group_id * ROWS_PER_SIMD + md_indices[1]) * params->O_strides[2] + md_indices[0]] = (T)(*it * softmax_denom); + } + + if (D == 128) { + #pragma clang loop unroll(full) + for (auto it = ctO_1.begin(); it != ctO_1.end(); it++) { + auto md_indices = it.get_multidimensional_index(); + + auto row_it = ct_row_sum.map_iterator(it); + + float softmax_denom = 1.0f / *row_it; + + O[(simd_group_id * ROWS_PER_SIMD + md_indices[1]) * params->O_strides[2] + md_indices[0] + MUL_SV_BD] = (T)(*it * softmax_denom); + } + } + } + else { + // Write out result + #pragma clang loop unroll(full) + for (auto it = ctO_0.begin(); it != ctO_0.end(); it++) { + auto md_indices = it.get_multidimensional_index(); + + short2 coords = short2(md_indices[0], md_indices[1] + simd_group_id * ROWS_PER_SIMD); + + auto row_it = ct_row_sum.map_iterator(it); + + float softmax_denom = 1.0f / *row_it; + + if (coords.y < params->qL_rem) + O[coords.y * params->O_strides[2] + coords.x] = (T)(*it * softmax_denom); + } + + if (D == 128) { + #pragma clang loop unroll(full) + for (auto it = ctO_1.begin(); it != ctO_1.end(); it++) { + auto md_indices = it.get_multidimensional_index(); + + short2 coords = short2(md_indices[0] + MUL_SV_BD, md_indices[1] + simd_group_id * ROWS_PER_SIMD); + + auto row_it = ct_row_sum.map_iterator(it); + + float softmax_denom = 1.0f / *row_it; + + if (coords.y < params->qL_rem) + O[coords.y * params->O_strides[2] + coords.x] = (T)(*it * softmax_denom); + } + } + } +} diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal index fee28fed10..8d3229f2b7 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal @@ -23,6 +23,23 @@ instantiate_attn_mask_helper(float16, half); instantiate_attn_mask_helper(bfloat16, bfloat16_t); - instantiate_attn_mask_helper(float32, float); + +#define instantiate_attn_tops(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \ + instantiate_kernel( \ + "steel_attention_tops_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \ + "_wm" #wm "_wn" #wn "_mask" #mname, \ + attention_tops, dtype, bq, bk, bd, wm, wn, mtype, float) + +#define instantiate_attn_shapes_helper_tops(iname, itype, mname, mtype) \ + instantiate_attn_tops(iname, itype, 64, 32, 128, 4, 1, mname, mtype) \ + instantiate_attn_tops(iname, itype, 64, 32, 80, 4, 1, mname, mtype) \ + instantiate_attn_tops(iname, itype, 64, 32, 64, 4, 1, mname, mtype) + +#define instantiate_attn_mask_helper_tops(iname, itype) \ + instantiate_attn_shapes_helper_tops(iname, itype, iname, itype) \ + instantiate_attn_shapes_helper_tops(iname, itype, bool_, bool) + +instantiate_attn_mask_helper_tops(float16, half); +instantiate_attn_mask_helper_tops(float32, float); // clang-format on diff --git a/mlx/backend/metal/kernels/steel/gemm/loader.h b/mlx/backend/metal/kernels/steel/gemm/loader.h index 3f084d8ecd..306a1745f2 100644 --- a/mlx/backend/metal/kernels/steel/gemm/loader.h +++ b/mlx/backend/metal/kernels/steel/gemm/loader.h @@ -133,5 +133,126 @@ struct BlockLoader { } }; +template < + typename T, + typename TgType, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = 
(BCOLS * BROWS) / (tgp_size),
+    short TCOLS = BCOLS / n_reads,
+    short TROWS = tgp_size / TCOLS>
+struct ConversionBlockLoader {
+  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
+  STEEL_CONST short vec_size = n_reads;
+
+  // Leading dimension for src
+  const int src_ld;
+  const int tile_stride;
+
+  // Thread location indices
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  // threadgroup and device memory
+  threadgroup TgType* dst;
+  const device T* src;
+
+  /* Constructor */
+  METAL_FUNC ConversionBlockLoader(
+      const device T* src_,
+      const int src_ld_,
+      threadgroup TgType* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(thread_idx / TCOLS),
+        bj(vec_size * (thread_idx % TCOLS)),
+        dst(dst_ + bi * dst_ld + bj),
+        src(src_ + bi * src_ld + bj) {}
+
+  /* Apply operation to threadgroup without bound checking */
+  template <typename UnaryOp>
+  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
+      }
+    }
+  }
+
+  /* Load from device memory into threadgroup memory - without bound checking */
+  METAL_FUNC void load_unsafe() const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * dst_ld + j] = TgType(src[i * src_ld + j]);
+      }
+    }
+  }
+
+  /* Load from device memory into threadgroup memory - with bound checking */
+  METAL_FUNC void load_safe(short2 src_tile_dim) const {
+    src_tile_dim = src_tile_dim - short2(bj, bi);
+
+    // Skip loading if thread has no valid reads
+    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < BROWS; i += TROWS) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; j++) {
+          dst[i * dst_ld + j] = T(0);
+        }
+      }
+      return;
+    }
+
+    // Use fast thread memory for bound checks
+    bool tmp_idx[vec_size];
+    T tmp_val[vec_size];
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      // Make sure tmp_idx only contains valid indices
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
+      }
+
+      // Read valid indices into tmp_val
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
+      }
+
+      // Zero out unneeded values
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
+      }
+
+      // Copy values to threadgroup memory
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * dst_ld + j] = TgType(tmp_val[j]);
+      }
+    }
+  }
+
+  /* Iteration helper */
+  METAL_FUNC void next() {
+    src += tile_stride;
+  }
+};
+
 } // namespace steel
 } // namespace mlx
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 6f5807543e..a510389c03 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -417,8 +417,8 @@ void qmm(
   int wm = 2;
   int wn = 2;
-  int bm = 32;
-  int bn = 32;
+  int bm = (transpose) ? 64 : 32;
+  int bn = (transpose) ? 64 : 32;
   MTL::Size group_dims(32, wn, wm);
   MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B);
@@ -688,8 +688,8 @@ void gather_qmm_rhs(
   array biases = ensure_row_contiguous(biases_, d, s);
 
   // TODO: Tune the block sizes
-  int bm = 16, bn = 32, bk = 32;
-  int wm = 1, wn = 2;
+  int bm = 32, bn = 32, bk = 32;
+  int wm = 2, wn = 2;
 
   const bool align_M = (M % bm) == 0;
   const bool align_N = (N % bn) == 0;
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 845962d01d..1bd5fac561 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -28,10 +28,6 @@ void sdpa_full_self_attention_metal(
   int wm = 4;
   int wn = 1;
 
-  int bd = q.shape(-1);
-  int bq = 32;
-  int bk = bd < 128 ? 32 : 16;
-
   int B = q.shape(0);
   int H = q.shape(1);
   int D = q.shape(3);
@@ -40,11 +36,23 @@
   int qL = q.shape(2);
   int kL = k.shape(2);
 
-  const bool align_Q = (qL % bq) == 0;
-  const bool align_K = (kL % bk) == 0;
   const bool has_mask = !!mask;
   const bool do_causal = do_causal_;
 
+  bool useTensorOps = q.dtype() != bfloat16 && !has_mask;
+
+  int bd = q.shape(-1);
+  int bq = 32;
+  int bk = bd < 128 ? 32 : 16;
+
+  if (useTensorOps) {
+    bq = 64;
+    bk = 32;
+  }
+
+  const bool align_Q = (qL % bq) == 0;
+  const bool align_K = (kL % bk) == 0;
+
   metal::MTLFCList func_consts = {
       {&align_Q, MTL::DataType::DataTypeBool, 200},
       {&align_K, MTL::DataType::DataTypeBool, 201},
@@ -54,6 +62,7 @@
   std::ostringstream kname;
   // clang-format off
   kname << "steel_attention_"
+        << (useTensorOps ? "tops_" : "")
        << type_to_name(q)
        << "_bq" << bq
        << "_bk" << bk
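For reference, two illustrative sketches follow. They are not part of the patch, and the helper names in them are hypothetical.

First, the byte layout assumed by the new bits == 5 branches: eight 5-bit values are packed little-endian into one 40-bit group (five bytes), and the masks and shifts added to dequantize(), qouter(), qdot() and the decode kernel fall directly out of that layout. A minimal host-side C++ round-trip check (pack_5bit / unpack_5bit are made-up names; unpack_5bit copies the kernel's byte masks):

    // Illustrative only -- not part of the patch.
    #include <cassert>
    #include <cstdint>

    // Pack eight 5-bit values into five bytes (little-endian within the 40-bit group).
    inline void pack_5bit(const uint8_t v[8], uint8_t b[5]) {
      uint64_t packed = 0;
      for (int i = 0; i < 8; ++i) {
        packed |= static_cast<uint64_t>(v[i] & 0x1f) << (5 * i);
      }
      for (int i = 0; i < 5; ++i) {
        b[i] = static_cast<uint8_t>(packed >> (8 * i));
      }
    }

    // Unpack with the same byte-wise masks/shifts the kernel uses for bits == 5.
    inline void unpack_5bit(const uint8_t b[5], uint8_t v[8]) {
      v[0] = b[0] & 0x1f;
      v[1] = ((b[0] & 0xe0) >> 5) | ((b[1] & 0x03) << 3);
      v[2] = (b[1] & 0x7c) >> 2;
      v[3] = ((b[1] & 0x80) >> 7) | ((b[2] & 0x0f) << 1);
      v[4] = ((b[2] & 0xf0) >> 4) | ((b[3] & 0x01) << 4);
      v[5] = (b[3] & 0x3e) >> 1;
      v[6] = ((b[3] & 0xc0) >> 6) | ((b[4] & 0x07) << 2);
      v[7] = (b[4] & 0xf8) >> 3;
    }

    int main() {
      uint8_t vals[8] = {0, 31, 7, 19, 1, 30, 12, 25};
      uint8_t bytes[5];
      uint8_t out[8];
      pack_5bit(vals, bytes);
      unpack_5bit(bytes, out);
      for (int i = 0; i < 8; ++i) {
        assert(out[i] == vals[i]); // round-trip preserves every 5-bit value
      }
      return 0;
    }

Second, the online-softmax bookkeeping in attention_tops: each query row keeps a running maximum and running sum, previously accumulated sums and outputs are rescaled by exp(old_max - new_max) whenever a new K/V block raises the maximum, and the accumulated output is divided by the row sum at the end. (The kernel folds log2(e) into softmax_scale so it can use exp2; the sketch uses std::exp for clarity.) A scalar C++ sketch of the same recurrence, with one value per score and a hypothetical streaming_softmax_dot name:

    // Illustrative only -- not part of the patch.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    float streaming_softmax_dot(
        const std::vector<std::vector<float>>& block_scores,  // scaled logits, one tile per block
        const std::vector<std::vector<float>>& block_values)  // one value per logit
    {
      float row_max = -INFINITY;
      float row_sum = 0.0f;
      float acc = 0.0f;  // plays the role of the ctO accumulator (1-D here)

      for (size_t b = 0; b < block_scores.size(); ++b) {
        // New running maximum over everything seen so far
        float tile_max = -INFINITY;
        for (float s : block_scores[b]) tile_max = std::max(tile_max, s);
        float new_max = std::max(row_max, tile_max);

        // Rescale previously accumulated sum and output by exp(old_max - new_max)
        float scale = std::exp(row_max - new_max);
        row_sum *= scale;
        acc *= scale;
        row_max = new_max;

        // Accumulate this block's exponentials and weighted values
        for (size_t i = 0; i < block_scores[b].size(); ++i) {
          float p = std::exp(block_scores[b][i] - new_max);
          row_sum += p;
          acc += p * block_values[b][i];
        }
      }
      return acc / row_sum;  // final division by the softmax denominator
    }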