153 changes: 141 additions & 12 deletions ggml/src/ggml-cuda/norm.cu
@@ -131,6 +131,51 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
}
}

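// RMS norm for a non-contiguous input: the grid assigns one block per
// (row, channel, sample) triple, x is addressed through the given element
// strides, and dst is written as a dense, contiguous tensor.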
template <int block_size>
static __global__ void rms_norm_f32_nc(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;

const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;

x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;

float tmp = 0.0f; // partial sum for thread in warp

for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
}

// sum up partial sums
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
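        // every warp re-reduces the per-warp partials (block_size/WARP_SIZE == 32),
        // so all threads end up holding the full block sum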
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}

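    // scale the row by 1/sqrt(mean(x^2) + eps)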
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);

for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
}
}

template <int block_size>
static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -165,6 +210,51 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa
}
}

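// Same normalization as rms_norm_f32_nc, fused with the element-wise
// multiplication by the norm weights y (a single row of ncols values).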
template <int block_size>
static __global__ void fused_rms_norm_f32_nc(
const float * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;

const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;

x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;

float tmp = 0.0f; // partial sum for thread in warp

for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
}

// sum up partial sums
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}

const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);

for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * y[col] * x[col];
}
}

static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
@@ -197,6 +287,19 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
}
}

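// launch the non-contiguous RMS norm kernel: one block per (row, channel, sample),
// using a single warp for rows with fewer than 1024 columns and a full
// 1024-thread block otherwise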
static void rms_norm_f32_nc_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}

static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
@@ -209,6 +312,19 @@ static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * ds
}
}

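// fused variant of the launcher above; same grid and block-size selection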
static void fused_rms_norm_f32_nc_cuda(
const float * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
fused_rms_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
fused_rms_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}

void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -255,18 +371,24 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);

float eps;
memcpy(&eps, dst->op_params, sizeof(float));

rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
const int64_t ne00 = src0->ne[0];
if (ggml_is_contiguous(src0)) {
const int64_t nrows = ggml_nrows(src0);
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
} else {
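        // non-contiguous src0: convert the byte strides in nb[] to element strides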
auto ts0 = ggml_type_size(src0->type);
GGML_ASSERT(src0->nb[0] == ts0);
auto s01 = src0->nb[1] / ts0;
auto s02 = src0->nb[2] / ts0;
auto s03 = src0->nb[3] / ts0;
rms_norm_f32_nc_cuda(src0_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
}
}

void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -281,19 +403,26 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
GGML_ASSERT(ggml_nrows(src1) == 1);

const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);

float eps;
memcpy(&eps, dst->op_params, sizeof(float));

fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
const int64_t ne00 = src0->ne[0];

if (ggml_is_contiguous(src0)) {
const int64_t nrows = ggml_nrows(src0);
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
} else {
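        // as above: derive element strides from the byte strides of src0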
auto ts0 = ggml_type_size(src0->type);
GGML_ASSERT(src0->nb[0] == ts0);
auto s01 = src0->nb[1] / ts0;
auto s02 = src0->nb[2] / ts0;
auto s03 = src0->nb[3] / ts0;
fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
}
}
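Note: a minimal sketch (not part of this diff) of the ggml-level pattern the new kernels handle. The names ctx0, src, ne0, ne1, and eps are illustrative, assuming a strided 2D view whose rows are not packed:

// before this change the CUDA backend required an explicit ggml_cont() on the view
struct ggml_tensor * view = ggml_view_2d(ctx0, src, ne0, ne1, src->nb[1], 0); // non-contiguous if src->nb[1] > ne0*sizeof(float)
struct ggml_tensor * out  = ggml_rms_norm(ctx0, view, eps);                   // now dispatched to rms_norm_f32_nc_cuda on CUDA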
6 changes: 3 additions & 3 deletions src/llama.cpp
@@ -13390,7 +13390,7 @@ struct llm_build_context {
ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
cb(k_pe, "k_pe", il);

kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
//kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
model.layers[il].attn_kv_a_norm, NULL,
LLM_NORM_RMS, cb, il);
@@ -13422,7 +13422,7 @@
0);
cb(v_states, "v_states", il);

q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
//q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
q_pe = ggml_rope_ext(
ctx0, q_pe, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -13431,7 +13431,7 @@
cb(q_pe, "q_pe", il);

// shared RoPE key
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
//k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
k_pe = ggml_rope_ext(
ctx0, k_pe, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,