Skip to content

Commit 9a6a6ef

Browse files
committed
CUDA/HIP: add support for selectable warp size to mmv
1 parent 7919256 commit 9a6a6ef

File tree

4 files changed

+28
-11
lines changed

4 files changed

+28
-11
lines changed

ggml/src/ggml-cuda/mmv.cu

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ static __global__ void mul_mat_vec(
1818
extern __shared__ char data_mmv[];
1919
float * buf_iw = (float *) data_mmv;
2020

21-
if (block_size > WARP_SIZE) {
22-
if (tid < WARP_SIZE) {
21+
if (block_size > GGML_TRUE_WARP_SIZE) {
22+
if (tid < GGML_TRUE_WARP_SIZE) {
2323
buf_iw[tid] = 0.0f;
2424
}
2525
__syncthreads();
@@ -67,16 +67,16 @@ static __global__ void mul_mat_vec(
6767
static_assert(std::is_same<T, void>::value, "unsupported type");
6868
}
6969

70-
sumf = warp_reduce_sum(sumf);
70+
sumf = warp_reduce_sum<GGML_TRUE_WARP_SIZE>(sumf);
7171

72-
if (block_size > WARP_SIZE) {
73-
buf_iw[tid/WARP_SIZE] = sumf;
72+
if (block_size > GGML_TRUE_WARP_SIZE) {
73+
buf_iw[tid/GGML_TRUE_WARP_SIZE] = sumf;
7474
__syncthreads();
75-
if (tid >= WARP_SIZE) {
75+
if (tid >= GGML_TRUE_WARP_SIZE) {
7676
return;
7777
}
7878
sumf = buf_iw[tid];
79-
sumf = warp_reduce_sum(sumf);
79+
sumf = warp_reduce_sum<GGML_TRUE_WARP_SIZE>(sumf);
8080
}
8181

8282
if (tid != 0) {
@@ -96,18 +96,27 @@ static void launch_mul_mat_vec_cuda(
9696
GGML_ASSERT(stride_row % 2 == 0);
9797
GGML_ASSERT(nchannels_y % nchannels_x == 0);
9898
const int64_t channel_ratio = nchannels_y / nchannels_x;
99+
int device;
100+
int warp_size;
99101

100-
int64_t block_size_best = WARP_SIZE;
101-
int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
102-
for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
102+
CUDA_CHECK(cudaGetDevice(&device));
103+
warp_size = ggml_cuda_info().devices[device].warp_size;
104+
105+
int64_t block_size_best = warp_size;
106+
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
107+
int64_t max_block_size = 256;
108+
if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
109+
max_block_size = 128;
110+
}
111+
for (int64_t block_size = 2*warp_size; block_size <= 128; block_size += warp_size) {
103112
const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
104113
if (niter < niter_best) {
105114
niter_best = niter;
106115
block_size_best = block_size;
107116
}
108117
}
109118

110-
const int smem = WARP_SIZE*sizeof(float);
119+
const int smem = warp_size*sizeof(float);
111120
const dim3 block_nums(nrows, 1, nchannels_y);
112121
const dim3 block_dims(block_size_best, 1, 1);
113122
switch (block_size_best) {

ggml/src/ggml-cuda/vendors/cuda.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@
1313
#define CUBLAS_COMPUTE_32F CUDA_R_32F
1414
#define cublasComputeType_t cudaDataType_t
1515
#endif // CUDART_VERSION < 11020
16+
17+
#define GGML_TRUE_WARP_SIZE 32 // Only use this in device code

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
34
#include <hip/hip_runtime.h>
45
#include <hipblas/hipblas.h>
56
#include <hip/hip_fp16.h>
@@ -8,6 +9,7 @@
89
// for rocblas_initialize()
910
#include "rocblas/rocblas.h"
1011
#endif // __HIP_PLATFORM_AMD__
12+
1113
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
1214
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
1315
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
@@ -137,6 +139,8 @@
137139
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
138140
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
139141

142+
#define GGML_TRUE_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE__ // Only use this in device code
143+
140144
#define __CUDA_ARCH__ 1300
141145

142146
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)

ggml/src/ggml-cuda/vendors/musa.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,5 @@
135135
#define cudaStreamEndCapture musaStreamEndCapture
136136

137137
typedef mt_bfloat16 nv_bfloat16;
138+
139+
#define GGML_TRUE_WARP_SIZE 32 // Only use this in device code

0 commit comments

Comments
 (0)