Skip to content

Commit 95c8b9c

Browse files
committed
Vectorize load instructions in dmmv f16 CUDA kernel
Replaces scalar with vector load instructions, which substantially improves performance on NVIDIA HBM GPUs, e.g. gives a 1.27X overall speedup for Meta-Llama-3-8B-Instruct-F16 BS1 inference evaluation on H100 SXM 80GB HBM3. On GDDR GPUs, there is a slight (1.01X) speedup.
1 parent e702206 commit 95c8b9c

File tree

1 file changed

+28
-9
lines changed

1 file changed

+28
-9
lines changed

ggml/src/ggml-cuda/dmmv.cu

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,11 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
416416

417417
static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
418418
const half * x = (const half *) vx;
419-
419+
// load 2 halfs into register in a single instruction
420+
const half2 x_reg = *((half2 *) &(x[ib + iqs]));
420421
// automatic half -> float type cast if dfloat == float
421-
v.x = x[ib + iqs + 0];
422-
v.y = x[ib + iqs + 1];
422+
v.x = x_reg.x;
423+
v.y = x_reg.y;
423424
}
424425

425426
static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
@@ -476,13 +477,31 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
476477
// matrix multiplication
477478
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
478479
#ifdef GGML_CUDA_F16
479-
tmp += __hmul2(v, {
480-
y[iybs + iqs + j/qr + 0],
481-
y[iybs + iqs + j/qr + y_offset]
482-
});
480+
if ( y_offset == 1 ) {
481+
// load 2 dfloats into register in a single instruction
482+
const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
483+
tmp += __hmul2(v, {
484+
y_reg.x;
485+
y_reg.y;
486+
});
487+
}
488+
else {
489+
tmp += __hmul2(v, {
490+
y[iybs + iqs + j/qr + 0],
491+
y[iybs + iqs + j/qr + y_offset]
492+
});
493+
}
483494
#else
484-
tmp += v.x * y[iybs + iqs + j/qr + 0];
485-
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
495+
if ( y_offset == 1 ) {
496+
// load 2 dfloats into register in a single instruction
497+
const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
498+
tmp += v.x * y_reg.x;
499+
tmp += v.y * y_reg.y;
500+
}
501+
else {
502+
tmp += v.x * y[iybs + iqs + j/qr + 0];
503+
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
504+
}
486505
#endif // GGML_CUDA_F16
487506
}
488507
}

0 commit comments

Comments
 (0)