Skip to content

Commit 4601a8c

Browse files
ikawrakow and Iwan Kawrakow authored
cuda: non-contiguous rms norm (#190)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent b08a2e9 commit 4601a8c

File tree

2 files changed

+144
-15
lines changed

2 files changed

+144
-15
lines changed

ggml/src/ggml-cuda/norm.cu

Lines changed: 141 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -131,6 +131,51 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
131131
}
132132
}
133133

134+
// RMS-normalizes every row of an f32 tensor whose rows need not be densely packed.
//
// Launch layout (as read from the built-in variables below):
//   gridDim    = (nrows, nchannels, nsamples), i.e. one thread block per row
//   blockDim.x = block_size (the column loops advance by block_size)
// Shared memory: 32 floats, only in the block_size > WARP_SIZE instantiation.
//
// x              input; the row start is located with the three strides below
// dst            output; written densely in (sample, channel, row, col) order
// ncols          number of elements per row
// stride_row     distance between consecutive rows of x, in float elements
// stride_channel distance between consecutive channels of x, in float elements
// stride_sample  distance between consecutive samples of x, in float elements
// eps            added to the mean of squares before the reciprocal square root
//
// NOTE(review): warp_reduce_sum is defined outside this block; the code below relies on
// it leaving the reduced value in every lane of the warp - confirm against its definition.
template <int block_size>
static __global__ void rms_norm_f32_nc(
        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;
    const int tid     = threadIdx.x;

    // The strides are int64_t, so this offset is already evaluated in 64 bits.
    x += sample*stride_sample + channel*stride_channel + row*stride_row;

    // Evaluate the dense output offset in 64 bits as well: with plain int operands the
    // product overflows once the tensor holds more than 2^31 elements.
    const int64_t dst_row = ((int64_t) sample*nchannels + channel)*nrows + row;
    dst += dst_row*ncols;

    float tmp = 0.0f; // partial sum of squares for this thread

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums: first within each warp ...
    tmp = warp_reduce_sum(tmp);
    if constexpr (block_size > WARP_SIZE) {
        // ... then across warps through shared memory. s_sum has 32 slots and is read by
        // lane index, which is only valid when the block holds exactly 32 warps.
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float mean  = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = scale * x[col];
    }
}
178+
134179
template <int block_size>
135180
static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
136181
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -165,6 +210,51 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa
165210
}
166211
}
167212

213+
// RMS-normalizes every row of an f32 tensor whose rows need not be densely packed and
// multiplies the result element-wise by the vector y in the same pass.
//
// Launch layout (as read from the built-in variables below):
//   gridDim    = (nrows, nchannels, nsamples), i.e. one thread block per row
//   blockDim.x = block_size (the column loops advance by block_size)
// Shared memory: 32 floats, only in the block_size > WARP_SIZE instantiation.
//
// x              input; the row start is located with the three strides below
// y              per-column multiplier, indexed by column only (same ncols values for every row)
// dst            output; written densely in (sample, channel, row, col) order
// ncols          number of elements per row
// stride_row     distance between consecutive rows of x, in float elements
// stride_channel distance between consecutive channels of x, in float elements
// stride_sample  distance between consecutive samples of x, in float elements
// eps            added to the mean of squares before the reciprocal square root
//
// NOTE(review): warp_reduce_sum is defined outside this block; the code below relies on
// it leaving the reduced value in every lane of the warp - confirm against its definition.
template <int block_size>
static __global__ void fused_rms_norm_f32_nc(
        const float * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;
    const int tid     = threadIdx.x;

    // The strides are int64_t, so this offset is already evaluated in 64 bits.
    x += sample*stride_sample + channel*stride_channel + row*stride_row;

    // Evaluate the dense output offset in 64 bits as well: with plain int operands the
    // product overflows once the tensor holds more than 2^31 elements.
    const int64_t dst_row = ((int64_t) sample*nchannels + channel)*nrows + row;
    dst += dst_row*ncols;

    float tmp = 0.0f; // partial sum of squares for this thread

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums: first within each warp ...
    tmp = warp_reduce_sum(tmp);
    if constexpr (block_size > WARP_SIZE) {
        // ... then across warps through shared memory. s_sum has 32 slots and is read by
        // lane index, which is only valid when the block holds exactly 32 warps.
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float mean  = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = scale * y[col] * x[col];
    }
}
257+
168258
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
169259
GGML_ASSERT(ncols % WARP_SIZE == 0);
170260
if (ncols < 1024) {
@@ -197,6 +287,19 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
197287
}
198288
}
199289

290+
// Host-side launcher for rms_norm_f32_nc.
//
// Enqueues one thread block per row on `stream`, using a grid of
// (nrows, nchannels, nsamples). Rows with fewer than 1024 columns use a block of
// WARP_SIZE threads; all other rows use a 1024-thread block. The strides and eps are
// forwarded to the kernel unchanged.
static void rms_norm_f32_nc_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 grid(nrows, nchannels, nsamples);

    if (ncols >= 1024) {
        // wide rows: 1024 threads cooperate on each row
        const dim3 threads(1024, 1, 1);
        rms_norm_f32_nc<1024><<<grid, threads, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        return;
    }

    // narrow rows: a single warp per row
    const dim3 threads(WARP_SIZE, 1, 1);
    rms_norm_f32_nc<WARP_SIZE><<<grid, threads, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
302+
200303
static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
201304
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
202305
GGML_ASSERT(ncols % WARP_SIZE == 0);
@@ -209,6 +312,19 @@ static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * ds
209312
}
210313
}
211314

315+
// Host-side launcher for fused_rms_norm_f32_nc.
//
// Enqueues one thread block per row on `stream`, using a grid of
// (nrows, nchannels, nsamples). Rows with fewer than 1024 columns use a block of
// WARP_SIZE threads; all other rows use a 1024-thread block. x, y, the strides and eps
// are forwarded to the kernel unchanged.
static void fused_rms_norm_f32_nc_cuda(
        const float * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 grid(nrows, nchannels, nsamples);

    if (ncols >= 1024) {
        // wide rows: 1024 threads cooperate on each row
        const dim3 threads(1024, 1, 1);
        fused_rms_norm_f32_nc<1024><<<grid, threads, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        return;
    }

    // narrow rows: a single warp per row
    const dim3 threads(WARP_SIZE, 1, 1);
    fused_rms_norm_f32_nc<WARP_SIZE><<<grid, threads, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
327+
212328
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
213329
const ggml_tensor * src0 = dst->src[0];
214330
const float * src0_d = (const float *)src0->data;
@@ -255,18 +371,24 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
255371
float * dst_d = (float *)dst->data;
256372
cudaStream_t stream = ctx.stream();
257373

258-
GGML_ASSERT(ggml_is_contiguous(src0));
259-
260374
GGML_ASSERT(src0->type == GGML_TYPE_F32);
261375
GGML_ASSERT( dst->type == GGML_TYPE_F32);
262376

263-
const int64_t ne00 = src0->ne[0];
264-
const int64_t nrows = ggml_nrows(src0);
265-
266377
float eps;
267378
memcpy(&eps, dst->op_params, sizeof(float));
268379

269-
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
380+
const int64_t ne00 = src0->ne[0];
381+
if (ggml_is_contiguous(src0)) {
382+
const int64_t nrows = ggml_nrows(src0);
383+
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
384+
} else {
385+
auto ts0 = ggml_type_size(src0->type);
386+
GGML_ASSERT(src0->nb[0] == ts0);
387+
auto s01 = src0->nb[1] / ts0;
388+
auto s02 = src0->nb[2] / ts0;
389+
auto s03 = src0->nb[3] / ts0;
390+
rms_norm_f32_nc_cuda(src0_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
391+
}
270392
}
271393

272394
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -281,19 +403,26 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
281403
float * dst_d = (float *)dst->data;
282404
cudaStream_t stream = ctx.stream();
283405

284-
GGML_ASSERT(ggml_is_contiguous(src0));
285-
286406
GGML_ASSERT(src0->type == GGML_TYPE_F32);
287407
GGML_ASSERT(src1->type == GGML_TYPE_F32);
288408
GGML_ASSERT( dst->type == GGML_TYPE_F32);
289409
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
290410
GGML_ASSERT(ggml_nrows(src1) == 1);
291411

292-
const int64_t ne00 = src0->ne[0];
293-
const int64_t nrows = ggml_nrows(src0);
294-
295412
float eps;
296413
memcpy(&eps, dst->op_params, sizeof(float));
297414

298-
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
415+
const int64_t ne00 = src0->ne[0];
416+
417+
if (ggml_is_contiguous(src0)) {
418+
const int64_t nrows = ggml_nrows(src0);
419+
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
420+
} else {
421+
auto ts0 = ggml_type_size(src0->type);
422+
GGML_ASSERT(src0->nb[0] == ts0);
423+
auto s01 = src0->nb[1] / ts0;
424+
auto s02 = src0->nb[2] / ts0;
425+
auto s03 = src0->nb[3] / ts0;
426+
fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
427+
}
299428
}

src/llama.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13390,7 +13390,7 @@ struct llm_build_context {
1339013390
ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
1339113391
cb(k_pe, "k_pe", il);
1339213392

13393-
kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
13393+
//kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
1339413394
kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
1339513395
model.layers[il].attn_kv_a_norm, NULL,
1339613396
LLM_NORM_RMS, cb, il);
@@ -13422,7 +13422,7 @@ struct llm_build_context {
1342213422
0);
1342313423
cb(v_states, "v_states", il);
1342413424

13425-
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
13425+
//q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
1342613426
q_pe = ggml_rope_ext(
1342713427
ctx0, q_pe, inp_pos, nullptr,
1342813428
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -13431,7 +13431,7 @@ struct llm_build_context {
1343113431
cb(q_pe, "q_pe", il);
1343213432

1343313433
// shared RoPE key
13434-
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
13434+
//k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
1343513435
k_pe = ggml_rope_ext(
1343613436
ctx0, k_pe, inp_pos, nullptr,
1343713437
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,

0 commit comments

Comments
 (0)