Commit bc49ef6

Revert "Fused Q and K fused_rms_norm for TG on CUDA (ikawrakow#882)"

This reverts commit 8c8a7fb.

1 parent d165961

5 files changed: 2 additions, 94 deletions

ggml/src/ggml-cuda.cu

Lines changed: 1 addition & 9 deletions
@@ -3244,15 +3244,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_rms_norm(ctx, dst);
             break;
         case GGML_OP_FUSED_RMS_NORM:
-            if (i + 2 < cgraph->n_nodes &&
-                cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
-                cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
-                dst->ne[2] == 1 && cgraph->nodes[i+2]->ne[2] == 1) {
-                ggml_cuda_op_fused_rms_rms_norm(ctx, dst, cgraph->nodes[i+2]);
-                i += 2;
-            } else {
-                ggml_cuda_op_fused_rms_norm(ctx, dst);
-            }
+            ggml_cuda_op_fused_rms_norm(ctx, dst);
             break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
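
For context, the branch removed above only fired for the specific node sequence that the Q/K-norm build path (see src/llama-build-context.cpp below) could produce, sketched here as a reading aid:

    nodes[i]    GGML_OP_FUSED_RMS_NORM   // Q-norm, ne[2] == 1 (presumably the single-token TG case from the commit title)
    nodes[i+1]  GGML_OP_VIEW
    nodes[i+2]  GGML_OP_FUSED_RMS_NORM   // K-norm, ne[2] == 1

When the pattern matched, both norms were handled by a single ggml_cuda_op_fused_rms_rms_norm launch and the node loop skipped ahead by two; in every other case, and after this revert in all cases, each FUSED_RMS_NORM goes through ggml_cuda_op_fused_rms_norm on its own.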

ggml/src/ggml-cuda/norm.cu

Lines changed: 0 additions & 81 deletions
@@ -619,84 +619,3 @@ void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx,
     fused_add_add_rms_norm_f32_cuda((const float *)add1->src[0]->data, (const float *)add1->src[1]->data, (const float *)add2->src[1]->data,
             src1_d, (float *)add2->data, dst_d, ne00, nrows, eps, stream);
 }
-
-template <int block_size>
-static __global__ void fused_rms_rms_norm_f32(int ncols, int nrows1, int nrows2, size_t nb1, size_t nb2, float eps,
-        const char *x1, const char * x2, const float * c1, const float * c2, float * y1, float * y2) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
-
-    auto x_row = (const float *)(row < nrows1 ? x1 + row*nb1 : x2 + (row - nrows1)*nb2);
-
-    float tmp = 0.0f; // partial sum for thread in warp
-
-    for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x_row[col];
-        tmp += xi * xi;
-    }
-
-    // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
-        tmp = warp_reduce_sum(tmp);
-    }
-
-    const float mean = tmp / ncols;
-    const float scale = rsqrtf(mean + eps);
-
-    auto dst = row < nrows1 ? y1 + row*ncols : y2 + (row - nrows1)*ncols;
-    auto c = row < nrows1 ? c1 : c2;
-
-    for (int col = tid; col < ncols; col += block_size) {
-        dst[col] = scale * c[col] * x_row[col];
-    }
-}
-
-static void fused_rms_rms_norm_f32_cuda(int ncols, int nrows1, int nrows2, size_t nb1, size_t nb2, float eps,
-        const char * x1, const char * x2, const float * c1, const float * c2, float * y1, float * y2, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    int nrows = nrows1 + nrows2;
-    if (ncols < 1024) {
-        const dim3 block_dims(256, 1, 1);
-        fused_rms_rms_norm_f32<256><<<nrows, block_dims, 0, stream>>>(ncols, nrows1, nrows2, nb1, nb2, eps, x1, x2, c1, c2, y1, y2);
-    } else {
-        const dim3 block_dims(1024, 1, 1);
-        fused_rms_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(ncols, nrows1, nrows2, nb1, nb2, eps, x1, x2, c1, c2, y1, y2);
-    }
-}
-
-void ggml_cuda_op_fused_rms_rms_norm([[maybe_unused]] ggml_backend_cuda_context & ctx, [[maybe_unused]] ggml_tensor * rms1, [[maybe_unused]] ggml_tensor * rms2) {
-    GGML_ASSERT(rms1->ne[2] == 1 && rms1->ne[3] == 1);
-    GGML_ASSERT(rms2->ne[2] == 1 && rms2->ne[3] == 1);
-    GGML_ASSERT(rms1->ne[0] == rms2->ne[0]);
-    GGML_ASSERT(rms1->type == GGML_TYPE_F32);
-    GGML_ASSERT(rms2->type == GGML_TYPE_F32);
-    GGML_ASSERT(rms1->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(rms2->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(rms1->src[0]->ne[0] == rms1->src[1]->ne[0]);
-    GGML_ASSERT(rms2->src[0]->ne[0] == rms2->src[1]->ne[0]);
-    GGML_ASSERT(ggml_nrows(rms1->src[1]) == 1);
-    GGML_ASSERT(ggml_nrows(rms2->src[1]) == 1);
-    GGML_ASSERT(rms1->src[1]->type == GGML_TYPE_F32);
-    GGML_ASSERT(rms2->src[1]->type == GGML_TYPE_F32);
-
-    float eps1, eps2;
-    memcpy(&eps1, rms1->op_params, sizeof(float));
-    memcpy(&eps2, rms2->op_params, sizeof(float));
-    GGML_ASSERT(eps1 == eps2);
-
-    fused_rms_rms_norm_f32_cuda(rms1->ne[0], rms1->ne[1], rms2->ne[1], rms1->nb[1], rms2->nb[1], eps1,
-            (const char *)rms1->src[0]->data, (const char *)rms2->src[0]->data,
-            (const float *)rms1->src[1]->data, (const float *)rms2->src[1]->data,
-            (float *)rms1->data, (float *)rms2->data, ctx.stream());
-
-
-}
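
As a reading aid only, here is a plain CPU sketch of what the deleted kernel computed. The helper name fused_rms_rms_norm_ref and its standalone form are made up for illustration; the row mapping mirrors the kernel above: rows of the first tensor come first and use weight vector c1, rows of the second follow and use c2, and every row is RMS-normalized with the shared eps.

    #include <math.h>
    #include <stddef.h>

    // Illustrative CPU reference (not part of the tree) for fused_rms_rms_norm_f32:
    // rows 0..nrows1-1 are read from x1, scaled by c1 and written to y1;
    // rows nrows1..nrows1+nrows2-1 are read from x2, scaled by c2 and written to y2.
    static void fused_rms_rms_norm_ref(int ncols, int nrows1, int nrows2, size_t nb1, size_t nb2, float eps,
            const char * x1, const char * x2, const float * c1, const float * c2, float * y1, float * y2) {
        for (int row = 0; row < nrows1 + nrows2; ++row) {
            const float * x = (const float *)(row < nrows1 ? x1 + row*nb1 : x2 + (row - nrows1)*nb2);
            const float * c =  row < nrows1 ? c1 : c2;
            float       * y =  row < nrows1 ? y1 + row*ncols : y2 + (row - nrows1)*ncols;
            float sum = 0.0f;
            for (int col = 0; col < ncols; ++col) sum += x[col]*x[col];              // sum of squares
            const float scale = 1.0f/sqrtf(sum/ncols + eps);                         // 1/RMS, matching rsqrtf(mean + eps)
            for (int col = 0; col < ncols; ++col) y[col] = scale * c[col] * x[col];  // normalize and apply per-channel weight
        }
    }

On the GPU, the deleted kernel did the same work with one thread block per row over the concatenated nrows1 + nrows2 rows, using warp_reduce_sum (plus shared memory when the block is wider than one warp) to form the per-row sum of squares.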

ggml/src/ggml-cuda/norm.cuh

Lines changed: 0 additions & 2 deletions
@@ -11,5 +11,3 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
 void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst);
 
 void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add1, ggml_tensor * add2, ggml_tensor * dst);
-
-void ggml_cuda_op_fused_rms_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * rms1, ggml_tensor * rms2);

src/llama-build-context.cpp

Lines changed: 0 additions & 2 deletions
@@ -1280,12 +1280,10 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
     if (q_norm) {
         Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
         cb(Qcur, "Qcur_normed", il);
-        ggml_build_forward_expand(gf, Qcur);
     }
     if (k_norm) {
         Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
         cb(Kcur, "Kcur_normed", il);
-        ggml_build_forward_expand(gf, Kcur);
     }

     return {Qcur, Kcur, Vcur};
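
One interpretive note (not stated in the commit itself): ggml_build_forward_expand appends a tensor and its not-yet-visited parents to the graph at the point where it is called, so expanding Qcur and Kcur right after their norms is what placed the two FUSED_RMS_NORM nodes close together in cgraph->nodes[], the arrangement the removed CUDA fast path keyed on. A minimal sketch of the removed idea, with the reasoning as comments:

    // sketch only: pin the Q/K norm results into the graph as soon as they are built,
    // so their FUSED_RMS_NORM nodes land next to each other in cgraph->nodes[]
    ggml_build_forward_expand(gf, Qcur);   // Q-norm node appended here
    ggml_build_forward_expand(gf, Kcur);   // K-norm node appended right after

After the revert the node ordering is again left entirely to the later graph expansion.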

src/llama-load-tensors.cpp

Lines changed: 1 addition & 0 deletions
@@ -2470,6 +2470,7 @@ bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias) {
     layer.wk = ml.create_tensor_as_view(ctx_split, layer.wqkv, wk_name.c_str(), { wk->ne[0], wk->ne[1] }, wq->ne[1]*wq->nb[1]);
     layer.wv = ml.create_tensor_as_view(ctx_split, layer.wqkv, wv_name.c_str(), { wv->ne[0], wv->ne[1] }, wq->ne[1]*wq->nb[1] + wk->ne[1]*wk->nb[1] );
     fused_qkv = true;
+    printf("================================== Created merged qkv %s\n", layer.wqkv->name);
     if (bias) {
         auto bq_name = tn(LLM_TENSOR_ATTN_Q, "bias", i);
         auto bk_name = tn(LLM_TENSOR_ATTN_K, "bias", i);
