409 changes: 216 additions & 193 deletions ggml/src/ggml-cuda.cu

Large diffs are not rendered by default.

564 changes: 564 additions & 0 deletions ggml/src/ggml-cuda/mmq_id.cu

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions ggml/src/ggml-cuda/mmq_id.cuh
@@ -0,0 +1,12 @@
#pragma once

#include "common.cuh"

void ggml_cuda_mul_mat_q_id(
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids,
ggml_tensor * dst, char * ids_data, char * src1_quantized_data);

void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream);

bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11);
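Taken together, the three declarations above suggest the dispatch flow for MUL_MAT_ID: check eligibility for the indirect-MMQ path, then hand the tensors plus pre-allocated scratch buffers to the fused kernel. A minimal caller sketch follows, assuming the usual ggml-cuda device-info accessor; `mul_mat_id_dispatch` and `mul_mat_id_fallback` are hypothetical stand-ins, not part of this header:

```cuda
// Sketch only: route MUL_MAT_ID through the indirect MMQ path when the
// quant type, compute capability and batch size qualify.
static void mul_mat_id_dispatch(ggml_backend_cuda_context & ctx,
        const ggml_tensor * src0, const ggml_tensor * src1,
        const ggml_tensor * ids, ggml_tensor * dst,
        char * ids_buf, char * src1_q8_buf) {                // scratch buffers, assumed pre-allocated
    const int cc = ggml_cuda_info().devices[ctx.device].cc;  // assumed ggml-cuda accessor
    if (ggml_cuda_can_use_mmq_id(src0->type, cc, src1->ne[1])) {
        ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, ids_buf, src1_q8_buf);
    } else {
        mul_mat_id_fallback(ctx, src0, src1, ids, dst);      // hypothetical non-MMQ path
    }
}
```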
4,183 changes: 4,183 additions & 0 deletions ggml/src/ggml-cuda/mmq_id_common.cuh

Large diffs are not rendered by default.

132 changes: 132 additions & 0 deletions ggml/src/ggml-cuda/quantize_id.cu
@@ -0,0 +1,132 @@
#include "quantize_id.cuh"
#include "mmq.cuh"
#include <cstdint>

template <mmq_q8_1_ds_layout ds_layout>
static __global__ void quantize_mmq_q8_1(
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int ne1, const int ne2) {

constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;

const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;

if (i0 >= ne0) {
return;
}

const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;

const int64_t i00 = i0;
const int64_t i01 = ids ? ids[i1] : i1;
const int64_t i02 = i2;
const int64_t i03 = i3;

const float4 * x4 = (const float4 *) x;

block_q8_1_mmq * y = (block_q8_1_mmq *) vy;

const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
const int64_t iqs = i0 % (4*QK8_1); // quant index in block

// Load 4 floats per thread and calculate the max. abs. value among them:
const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
float amax = fabsf(xi.x);
amax = fmaxf(amax, fabsf(xi.y));
amax = fmaxf(amax, fabsf(xi.z));
amax = fmaxf(amax, fabsf(xi.w));

// Exchange the max. abs. value across vals_per_scale/4 threads.
#pragma unroll
for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
}

float sum;
if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
sum = xi.x + xi.y + xi.z + xi.w;

// Calculate sums across vals_per_sum/4 threads.
#pragma unroll
for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
}
}

const float d_inv = 127.0f / amax;
char4 q;
q.x = roundf(xi.x*d_inv);
q.y = roundf(xi.y*d_inv);
q.z = roundf(xi.z*d_inv);
q.w = roundf(xi.w*d_inv);

// Write back 4 int8 values as a single 32 bit value for better memory bandwidth:
char4 * yqs4 = (char4 *) y[ib].qs;
yqs4[iqs/4] = q;

if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
if (iqs % 16 != 0 || iqs >= 96) {
return;
}

y[ib].d2s6[2 + iqs/16] = sum;

if (iqs % 64 != 0) {
return;
}

const float d = 1.0f / d_inv;

y[ib].d2s6[iqs/64] = d;

return;
}

if (iqs % 32 != 0) {
return;
}

const float d = 1.0f / d_inv;

if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
y[ib].ds4[iqs/32] = make_half2(d, sum);
} else {
y[ib].d4[iqs/32] = d;
}
}
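Per 32-value group, the kernel above implements the standard q8_1 scheme: scale d = amax/127, quantized values q_i = round(x_i * 127/amax), plus the group sum that the dot-product kernels later use to fold in the q8_1 offset term without re-reading the floats. A scalar restatement as a sketch only (the amax == 0 guard is added here for the standalone version; the kernel's tiled layout and warp shuffles are omitted):

```cuda
#include <cmath>
#include <cstdint>

// Reference for one 32-value q8_1 group, scalar form.
static void quantize_group_q8_1_ref(const float * x, int8_t * q, float & d, float & sum) {
    float amax = 0.0f;
    for (int i = 0; i < 32; ++i) {
        amax = fmaxf(amax, fabsf(x[i]));
    }
    d   = amax / 127.0f;                                   // dequantization scale
    const float d_inv = amax > 0.0f ? 127.0f/amax : 0.0f;  // matches the kernel's 127/amax
    sum = 0.0f;
    for (int i = 0; i < 32; ++i) {
        q[i] = (int8_t) roundf(x[i]*d_inv);                // x_i ~= q_i * d
        sum += x[i];                                       // folded into the dot product later
    }
}
```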

void quantize_mmq_q8_1_cuda_id(
const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
GGML_ASSERT(ids);

// ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
switch (mmq_get_q8_1_ds_layout(type_src0)) {
case MMQ_Q8_1_DS_LAYOUT_D4:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
break;
case MMQ_Q8_1_DS_LAYOUT_DS4:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
break;
case MMQ_Q8_1_DS_LAYOUT_D2S6:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
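To make the launch geometry concrete: each thread quantizes 4 consecutive floats, so one block of CUDA_QUANTIZE_BLOCK_SIZE_MMQ = 128 threads covers 512 values of a padded row; rows go on `x`, row chunks on `y`, channels and samples on `z`. A small illustration with hypothetical sizes:

```cuda
#include <cstdint>
#include <cuda_runtime.h>

#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128  // mirrors quantize_id.cuh

// Illustration only: the grid computed by quantize_mmq_q8_1_cuda_id above.
static dim3 example_grid(int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
    const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
    return dim3((unsigned)ne1, (unsigned)block_num_y, (unsigned)(ne2*ne3));
}

// e.g. example_grid(4096, 8, 4, 1) -> (8, 8, 4): 256 blocks of 128 threads,
// each thread handling one float4 (ne0 = 4096 padded values per row).
```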
16 changes: 16 additions & 0 deletions ggml/src/ggml-cuda/quantize_id.cuh
@@ -0,0 +1,16 @@
#pragma once

#include "common.cuh"

#include <cstdint>

#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128

//static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
//static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");

void quantize_mmq_q8_1_cuda_id(
const float * x, const int32_t * ids, void * vy,
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
83 changes: 83 additions & 0 deletions ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt_id.cu
@@ -0,0 +1,83 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq_id_common.cuh"

template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_kt(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {

constexpr int nwarps = mmq_get_nwarps_device();

constexpr uint32_t ka = 0xCBAC1FED;
constexpr uint32_t km = 0x3f3f3f3f;

#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

const int kqsx = threadIdx.x;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;

if (need_check) {
i = min(i, i_max);
}

const block_iq1_kt * bxi = (const block_iq1_kt *)(x + i*stride + sizeof(float)) + kbx0;

int ib32 = kqsx/4;
int j = kqsx%4;
uint32_t val = bxi->ql[kqsx] + ((bxi->qh[kqsx%16] << (8 - 4*(kqsx/16))) & 0xf00) + ((bxi->sh[kqsx/4] << (8 - (kqsx%4))) & 0x1000) + 4096;
int2 v = {0, 0};
for (int k = 0; k < 4; ++k) {
val *= ka;
v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
}
for (int k = 0; k < 4; ++k) {
val *= ka;
v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
}
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y;
#endif // INT8_MMA_AVAILABLE
}

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);

if (need_check) {
i = min(i, i_max);
}

const float * dptr = (const float *)(x + i*stride);
const float d = dptr[0];
const block_iq1_kt * bxi = (const block_iq1_kt *)(dptr + 1) + kbx0;
const int ls = iq4k_values[bxi->sh[threadIdx.x % 8] & 0xf];

#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls;
#else
x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls;
#endif // INT8_MMA_AVAILABLE
}
}

template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_KT> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_kt<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};

DECL_MMQ_CASE(GGML_TYPE_IQ1_KT);
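The decode in `load_tiles_iq1_kt` above reconstructs int8 values from a 13-bit trellis seed (assembled from `ql`, `qh` and `sh`, offset by 4096): each step advances the state with the multiplier ka = 0xCBAC1FED, masks every byte to 6 bits with km, and `ggml_cuda_dp4a(val & km, 0x01010101, -126)` sums the four masked bytes and re-centers the result into [-126, 126]. A scalar restatement of one step, sketch only (the helper name is hypothetical):

```cuda
#include <cstdint>

// One iq1_kt trellis step, scalar form: equivalent to the kernel's
// val *= ka; ggml_cuda_dp4a(val & km, 0x01010101, -126).
static int8_t iq1_kt_trellis_step(uint32_t & val) {
    constexpr uint32_t ka = 0xCBAC1FED;  // multiplicative trellis constant
    constexpr uint32_t km = 0x3f3f3f3f;  // keep the low 6 bits of each byte
    val *= ka;                           // advance the trellis state
    const uint32_t m = val & km;         // four bytes, each in [0, 63]
    const int s = (int)( m        & 0xff) + (int)((m >>  8) & 0xff)
                + (int)((m >> 16) & 0xff) + (int)((m >> 24) & 0xff) - 126;
    return (int8_t) s;                   // in [-126, 126], packed into v.x/v.y by the kernel
}
```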
@@ -0,0 +1,6 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq_id_common.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);