27 changes: 27 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -563,6 +563,33 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#endif // CUDART_VERSION >= 12050
}

// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) {
// compute L = ceil(log2(d));
L = 0;
while (L < 32 && (uint32_t{ 1 } << L) < d) {
L++;
}

mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
}

static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
// Compute high 32 bits of n * mp
uint32_t hi = __umulhi(n, mp);
// Apply the formula
return (hi + n) >> L;
}

static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, uint32_t divisor, uint32_t mp, uint32_t L) {
return n - fastdiv(n, mp, L) * divisor;
}

typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);

static __device__ __forceinline__ float get_alibi_slope(
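A quick host-side check of the precomputation above (not part of this diff): the quotient from (mulhi(n, mp) + n) >> L and the remainder from fastmodulo should match plain / and % for any divisor d > 0, as long as n is small enough that hi + n stays within 32 bits, which holds for the row/channel/sample/column counts the kernels pass in. A minimal standalone sketch, with umulhi_ref standing in for __umulhi on the host:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Host re-implementation of init_fastdiv_values from the hunk above.
static void init_fastdiv_values_ref(uint32_t d, uint32_t & mp, uint32_t & L) {
    L = 0;
    while (L < 32 && (uint32_t{ 1 } << L) < d) {
        L++;
    }
    mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
}

// Host stand-in for __umulhi: the high 32 bits of the 64-bit product.
static uint32_t umulhi_ref(uint32_t a, uint32_t b) {
    return (uint32_t) (((uint64_t) a * b) >> 32);
}

int main() {
    const uint32_t divisors[] = { 1, 2, 3, 5, 7, 192, 4096, 640000 };
    for (uint32_t d : divisors) {
        uint32_t mp, L;
        init_fastdiv_values_ref(d, mp, L);
        // n is kept small enough that hi + n cannot overflow 32 bits.
        for (uint32_t n = 0; n < 1000000; n += 37) {
            const uint32_t q = (umulhi_ref(n, mp) + n) >> L;
            const uint32_t r = n - q * d;
            assert(q == n / d);
            assert(r == n % d);
        }
    }
    printf("fastdiv/fastmodulo agree with / and %% on all tested values\n");
    return 0;
}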
226 changes: 136 additions & 90 deletions ggml/src/ggml-cuda/norm.cu
@@ -105,29 +105,45 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}

template <int block_size, bool do_multiply = false, bool do_add = false>
static __global__ void rms_norm_f32(const float * x, float * dst,
const int ncols,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const float eps,
const float * mul = nullptr,
const int64_t mul_stride_row = 0,
const int64_t mul_stride_channel = 0,
const int64_t mul_stride_sample = 0,
const int mul_ncols = 0,
const int mul_nrows = 0,
const int mul_nchannels = 0,
const int mul_nsamples = 0,
const float * add = nullptr,
const int64_t add_stride_row = 0,
const int64_t add_stride_channel = 0,
const int64_t add_stride_sample = 0,
const int add_ncols = 0,
const int add_nrows = 0,
const int add_nchannels = 0,
const int add_nsamples = 0) {

static __global__ void rms_norm_f32(const float * x,
float * dst,
const int ncols,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const float eps,
const float * mul = nullptr,
const int64_t mul_stride_row = 0,
const int64_t mul_stride_channel = 0,
const int64_t mul_stride_sample = 0,
const uint32_t mul_ncols = 0,
const uint32_t mul_nrows = 0,
const uint32_t mul_nchannels = 0,
const uint32_t mul_nsamples = 0,
const uint32_t mp_mul_cols = 0,
const uint32_t L_mul_cols = 0,
const uint32_t mp_mul_rows = 0,
const uint32_t L_mul_rows = 0,
const uint32_t mp_mul_channels = 0,
const uint32_t L_mul_channels = 0,
const uint32_t mp_mul_samples = 0,
const uint32_t L_mul_samples = 0,
const float * add = nullptr,
const int64_t add_stride_row = 0,
const int64_t add_stride_channel = 0,
const int64_t add_stride_sample = 0,
const uint32_t add_ncols = 0,
const uint32_t add_nrows = 0,
const uint32_t add_nchannels = 0,
const uint32_t add_nsamples = 0,
const uint32_t mp_add_cols = 0,
const uint32_t L_add_cols = 0,
const uint32_t mp_add_rows = 0,
const uint32_t L_add_rows = 0,
const uint32_t mp_add_channels = 0,
const uint32_t L_add_channels = 0,
const uint32_t mp_add_samples = 0,
const uint32_t L_add_samples = 0) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;

@@ -142,16 +158,16 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
dst += ((sample*nchannels + channel)*nrows + row)*ncols;

if constexpr (do_multiply) {
const int mul_row = row % mul_nrows;
const int mul_channel = channel % mul_nchannels;
const int mul_sample = sample % mul_nsamples;
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
const uint32_t mul_row = fastmodulo(row, mul_nrows, mp_mul_rows, L_mul_rows);
const uint32_t mul_channel = fastmodulo(channel, mul_nchannels, mp_mul_channels, L_mul_channels);
const uint32_t mul_sample = fastmodulo(sample, mul_nsamples, mp_mul_samples, L_mul_samples);
mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
}

if constexpr (do_add) {
const int add_row = row % add_nrows;
const int add_channel = channel % add_nchannels;
const int add_sample = sample % add_nsamples;
const int add_row = fastmodulo(row, add_nrows, mp_add_rows, L_add_rows);
const int add_channel = fastmodulo(channel, add_nchannels, mp_add_channels, L_add_channels);
const int add_sample = fastmodulo(sample, add_nsamples, mp_add_samples, L_add_samples);
add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
}

@@ -165,15 +181,18 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = tid / WARP_SIZE;
const int lane_id = tid % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = 0.0f;
if (lane_id < (block_size / WARP_SIZE)) {
tmp = s_sum[lane_id];
}
tmp = warp_reduce_sum(tmp);
}
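
The loosened static_assert above lets the same reduction serve any block size that is a multiple of WARP_SIZE up to 1024, with only the first block_size / WARP_SIZE shared-memory slots holding valid per-warp partial sums. A standalone sketch of that two-stage pattern (hypothetical kernel, WARP_SIZE assumed to be 32, one block per output value; not part of this diff):

#define WARP_SIZE 32

static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
    // butterfly reduction across the 32 lanes of a warp
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xFFFFFFFF, x, offset, WARP_SIZE);
    }
    return x;
}

template <int block_size>
static __global__ void block_sum_sketch(const float * x, float * out, const int n) {
    const int tid = threadIdx.x;

    float tmp = 0.0f;
    for (int i = tid; i < n; i += block_size) {
        tmp += x[i];
    }

    // stage 1: per-warp partial sums
    tmp = warp_reduce_sum_sketch(tmp);

    if constexpr (block_size > WARP_SIZE) {
        static_assert((block_size <= 1024) && (block_size % WARP_SIZE == 0), "unexpected block_size");
        __shared__ float s_sum[32];
        const int warp_id = tid / WARP_SIZE;
        const int lane_id = tid % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        // stage 2: only block_size/WARP_SIZE slots were written; the rest read as 0
        tmp = lane_id < block_size / WARP_SIZE ? s_sum[lane_id] : 0.0f;
        tmp = warp_reduce_sum_sketch(tmp);
    }

    if (tid == 0) {
        out[blockIdx.x] = tmp;
    }
}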

@@ -182,12 +201,12 @@ static __global__ void rms_norm_f32(const float * x, float * dst,

for (int col = tid; col < ncols; col += block_size) {
if constexpr (do_multiply && do_add) {
const int mul_col = col % mul_ncols;
const int add_col = col % add_ncols;
dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
const int mul_col = fastmodulo(col, mul_ncols, mp_mul_cols, L_mul_cols);
const int add_col = fastmodulo(col, add_ncols, mp_add_cols, L_add_cols);
dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
} else if constexpr (do_multiply) {
const int mul_col = col % mul_ncols;
dst[col] = scale * x[col] * mul[mul_col];
const int mul_col = fastmodulo(col, mul_ncols, mp_mul_cols, L_mul_cols);
dst[col] = scale * x[col] * mul[mul_col];
} else {
dst[col] = scale * x[col];
}
@@ -354,77 +373,104 @@ static void rms_norm_f32_cuda(
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}

static void rms_norm_mul_f32_cuda(const float * x,
const float * mul,
const float * add,
float * dst,
const int ncols,
const int nrows,
const int nchannels,
const int nsamples,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const int64_t mul_stride_row,
const int64_t mul_stride_channel,
const int64_t mul_stride_sample,
const int mul_ncols,
const int mul_nrows,
const int mul_nchannels,
const int mul_nsamples,
const int64_t add_stride_row,
const int64_t add_stride_channel,
const int64_t add_stride_sample,
const int add_ncols,
const int add_nrows,
const int add_nchannels,
const int add_nsamples,
const float eps,
cudaStream_t stream) {
static void rms_norm_mul_f32_cuda(const float * x,
const float * mul,
const float * add,
float * dst,
const int ncols,
const int nrows,
const int nchannels,
const int nsamples,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const int64_t mul_stride_row,
const int64_t mul_stride_channel,
const int64_t mul_stride_sample,
const uint32_t mul_ncols,
const uint32_t mul_nrows,
const uint32_t mul_nchannels,
const uint32_t mul_nsamples,
const int64_t add_stride_row,
const int64_t add_stride_channel,
const int64_t add_stride_sample,
const uint32_t add_ncols,
const uint32_t add_nrows,
const uint32_t add_nchannels,
const uint32_t add_nsamples,
const float eps,
cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
if (add == nullptr) {
uint32_t mp_mul_cols, L_mul_cols;
init_fastdiv_values(mul_ncols, mp_mul_cols, L_mul_cols);
uint32_t mp_mul_rows, L_mul_rows;
init_fastdiv_values(mul_nrows, mp_mul_rows, L_mul_rows);
uint32_t mp_mul_channels, L_mul_channels;
init_fastdiv_values(mul_nchannels, mp_mul_channels, L_mul_channels);
uint32_t mp_mul_samples, L_mul_samples;
init_fastdiv_values(mul_nsamples, mp_mul_samples, L_mul_samples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
ncols, stride_row, stride_channel, stride_sample, eps,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
ncols, stride_row, stride_channel, stride_sample, eps,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples);
}
} else {
uint32_t mp_mul_cols, L_mul_cols;
init_fastdiv_values(mul_ncols, mp_mul_cols, L_mul_cols);
uint32_t mp_mul_rows, L_mul_rows;
init_fastdiv_values(mul_nrows, mp_mul_rows, L_mul_rows);
uint32_t mp_mul_channels, L_mul_channels;
init_fastdiv_values(mul_nchannels, mp_mul_channels, L_mul_channels);
uint32_t mp_mul_samples, L_mul_samples;
init_fastdiv_values(mul_nsamples, mp_mul_samples, L_mul_samples);

uint32_t mp_add_cols, L_add_cols;
init_fastdiv_values(add_ncols, mp_add_cols, L_add_cols);
uint32_t mp_add_rows, L_add_rows;
init_fastdiv_values(add_nrows, mp_add_rows, L_add_rows);
uint32_t mp_add_channels, L_add_channels;
init_fastdiv_values(add_nchannels, mp_add_channels, L_add_channels);
uint32_t mp_add_samples, L_add_samples;
init_fastdiv_values(add_nsamples, mp_add_samples, L_add_samples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
ncols, stride_row, stride_channel, stride_sample, eps,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
add, add_stride_row, add_stride_channel, add_stride_sample,
add_ncols, add_nrows, add_nchannels, add_nsamples);
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols, add_nrows, add_nchannels,
add_nsamples, mp_add_cols, L_add_cols, mp_add_rows, L_add_rows, mp_add_channels, L_add_channels,
mp_add_samples, L_add_samples);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
ncols, stride_row, stride_channel, stride_sample, eps,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
add, add_stride_row, add_stride_channel, add_stride_sample,
add_ncols, add_nrows, add_nchannels, add_nsamples);
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols, add_nrows, add_nchannels,
add_nsamples, mp_add_cols, L_add_cols, mp_add_rows, L_add_rows, mp_add_channels, L_add_channels,
mp_add_samples, L_add_samples);
}
}
}
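
The kernel now takes eight extra (mp, L) pairs, which makes these launch wrappers fairly verbose. Not part of this change, but one way to keep that in check would be to bundle each broadcast extent with its magic constants and pass a single value per dimension; a hypothetical sketch that reuses init_fastdiv_values and fastdiv from the common.cuh hunk above:

// Hypothetical helper, not in this PR: one value per broadcast dimension.
struct fastdiv_values {
    uint32_t d;   // divisor, e.g. mul_ncols
    uint32_t mp;  // magic multiplier m'
    uint32_t L;   // shift amount
};

// Host side: precompute once per dimension before the launch.
static fastdiv_values init_fastdiv(uint32_t d) {
    fastdiv_values v;
    v.d = d;
    init_fastdiv_values(d, v.mp, v.L);
    return v;
}

// Device side: same arithmetic as fastmodulo, one argument instead of three.
static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const fastdiv_values v) {
    return n - fastdiv(n, v.mp, v.L) * v.d;
}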