Skip to content

Commit 6dec189

Browse files
committed
sycl: optimize rms_norm_back
1 parent db7f078 commit 6dec189

File tree

1 file changed

+68
-47
lines changed

1 file changed

+68
-47
lines changed

ggml/src/ggml-sycl/norm.cpp

Lines changed: 68 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -496,45 +496,40 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d
496496
float * dx_base = static_cast< float *>(dst->data);
497497

498498
const int64_t D = dst->ne[0];
499-
const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3];
500-
(void) n3;
499+
const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3]; (void) n3;
501500
const int64_t N = ggml_nrows(dst);
502501
if (D == 0 || N == 0) return;
503502

504503
const ggml_tensor *G = dst->src[0];
505504
const ggml_tensor *X = dst->src[1];
506505
const int ts = (int) ggml_type_size(X->type);
507-
GGML_ASSERT((size_t) X->nb[0] == (size_t) ts);
508-
GGML_ASSERT((size_t) G->nb[0] == (size_t) ts);
506+
GGML_ASSERT((size_t) X->nb[0] == (size_t) ts);
507+
GGML_ASSERT((size_t) G->nb[0] == (size_t) ts);
509508
GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts);
510509

511-
const int64_t xs1 = X->nb[1] / ts;
512-
const int64_t xs2 = X->nb[2] / ts;
513-
const int64_t xs3 = X->nb[3] / ts;
514-
const int64_t gs1 = G->nb[1] / ts;
515-
const int64_t gs2 = G->nb[2] / ts;
516-
const int64_t gs3 = G->nb[3] / ts;
517-
const int64_t ds1 = dst->nb[1] / ts;
518-
const int64_t ds2 = dst->nb[2] / ts;
519-
const int64_t ds3 = dst->nb[3] / ts;
510+
const int64_t xs1 = X->nb[1] / ts, xs2 = X->nb[2] / ts, xs3 = X->nb[3] / ts;
511+
const int64_t gs1 = G->nb[1] / ts, gs2 = G->nb[2] / ts, gs3 = G->nb[3] / ts;
512+
const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts;
520513

521514
dpct::queue_ptr q = ctx.stream();
522515

523-
// work-group size: multiple of WARP_SIZE, capped by device and 256
516+
// work-group size: multiple of WARP_SIZE, capped by device and 256, and not larger than D
524517
const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device];
525518
auto roundup = [](int v, int m) { return ((v + m - 1) / m) * m; };
526519
int wg_cap = 256;
527520
if (device_max_wg > 0) wg_cap = std::min(wg_cap, device_max_wg);
528521
int WG = std::max(WARP_SIZE, std::min(roundup((int)std::min<int64_t>(D, wg_cap), WARP_SIZE), wg_cap));
529522

530-
// FP32 compensated reduction
523+
// FP32 path: per-thread compensated accumulation + hierarchical reduction
531524
q->submit([&](sycl::handler &cgh) {
532-
auto l_inv_r = sycl::local_accessor<float, 1>(sycl::range<1>(1), cgh);
533-
auto l_coeff = sycl::local_accessor<float, 1>(sycl::range<1>(1), cgh);
534-
auto l_part = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(std::max(1, WG / WARP_SIZE)), cgh);
525+
const int nwarps_loc = std::max(1, WG / WARP_SIZE);
526+
// store one partial value per warp (xx and xg) for cross-warp reduction
527+
auto l_xx = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
528+
auto l_xg = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
535529

536530
cgh.parallel_for(
537-
sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG), sycl::range<3>(1, 1, WG)),
531+
sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG),
532+
sycl::range<3>(1, 1, WG)),
538533
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
539534
const int row = item_ct1.get_group(2);
540535
const int tid = item_ct1.get_local_id(2);
@@ -547,29 +542,49 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d
547542
const float *__restrict g_row = g_base + i3 * gs3 + i2 * gs2 + i1 * gs1;
548543
float *__restrict d_row = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1;
549544

550-
// per-thread compensated sums for sum(x^2) and sum(x*dz)
551-
float sum_xx = 0.f, c_xx = 0.f;
552-
float sum_xg = 0.f, c_xg = 0.f;
545+
// per-thread accumulation (compensated by default)
546+
float sum_xx = 0.f, sum_xg = 0.f;
547+
#ifndef GGML_SYCL_RMS_BACK_FAST
548+
float c_xx = 0.f, c_xg = 0.f;
549+
#endif
553550
for (int64_t col = tid; col < D; col += WG) {
554551
const float xv = x_row[col];
555552
const float gv = g_row[col];
556-
float y1 = xv * xv - c_xx; // compensated add for x^2
557-
float t1 = sum_xx + y1;
553+
#ifdef GGML_SYCL_RMS_BACK_FAST
554+
sum_xx += xv * xv;
555+
sum_xg += xv * gv;
556+
#else
557+
float y1 = xv * xv - c_xx;
558+
float t1 = sum_xx + y1;
558559
c_xx = (t1 - sum_xx) - y1;
559560
sum_xx = t1;
560-
float y2 = xv * gv - c_xg; // compensated add for x*dz
561-
float t2 = sum_xg + y2;
561+
562+
float y2 = xv * gv - c_xg;
563+
float t2 = sum_xg + y2;
562564
c_xg = (t2 - sum_xg) - y2;
563565
sum_xg = t2;
566+
#endif
564567
}
565568

566-
// reduce within warp
567-
sycl::float2 xx = sycl::float2(sum_xx, c_xx);
568-
sycl::float2 xg = sycl::float2(sum_xg, c_xg);
569+
// warp-level reduction
570+
sycl::float2 xx = sycl::float2(sum_xx,
571+
#ifndef GGML_SYCL_RMS_BACK_FAST
572+
c_xx
573+
#else
574+
0.f
575+
#endif
576+
);
577+
sycl::float2 xg = sycl::float2(sum_xg,
578+
#ifndef GGML_SYCL_RMS_BACK_FAST
579+
c_xg
580+
#else
581+
0.f
582+
#endif
583+
);
569584
xx = warp_reduce_sum(xx, item_ct1);
570585
xg = warp_reduce_sum(xg, item_ct1);
571586

572-
// cross-warp reduction if needed
587+
// cross-warp reduction using local memory (single barrier)
573588
const auto sub_group = item_ct1.get_sub_group();
574589
const auto sg_id = sub_group.get_group_linear_id();
575590
const auto wi_in_sg = sub_group.get_local_linear_id();
@@ -579,34 +594,40 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d
579594
sycl::float2 xx_total = xx;
580595
sycl::float2 xg_total = xg;
581596
if (nwarps > 1) {
582-
if (wi_in_sg == 0) l_part[sg_id] = xx;
597+
if (wi_in_sg == 0) {
598+
l_xx[sg_id] = xx;
599+
l_xg[sg_id] = xg;
600+
}
583601
item_ct1.barrier(sycl::access::fence_space::local_space);
584602

585-
xx_total = sycl::float2(0.f, 0.f);
586-
const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
587-
for (size_t i = 0; i < nreduce; ++i) xx_total += l_part[wi_in_sg + i * WARP_SIZE];
588-
xx_total = warp_reduce_sum(xx_total, item_ct1);
589-
590-
if (wi_in_sg == 0) l_part[sg_id] = xg;
591-
item_ct1.barrier(sycl::access::fence_space::local_space);
592-
xg_total = sycl::float2(0.f, 0.f);
593-
for (size_t i = 0; i < nreduce; ++i) xg_total += l_part[wi_in_sg + i * WARP_SIZE];
594-
xg_total = warp_reduce_sum(xg_total, item_ct1);
603+
if (sg_id == 0) {
604+
const unsigned wi_u = wi_in_sg;
605+
sycl::float2 xx_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f);
606+
sycl::float2 xg_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xg[wi_u] : sycl::float2(0.f, 0.f);
607+
xx_total = warp_reduce_sum(xx_first, item_ct1);
608+
xg_total = warp_reduce_sum(xg_first, item_ct1);
609+
} else {
610+
// other subgroups keep their local totals; they'll be ignored
611+
xx_total = xx;
612+
xg_total = xg;
613+
}
614+
// ensure all threads see the first-subgroup result via broadcast below
595615
}
596616

597-
// compute inv_r and coeff once per row
617+
// compute inv_r and coeff once per row and broadcast to the whole work-group
618+
float inv_r = 0.f;
619+
float coeff = 0.f;
598620
if (tid == 0) {
599621
const float sum_xx_f = xx_total.x() + xx_total.y();
600622
const float sum_xdz_f = xg_total.x() + xg_total.y();
601623
const float mean_eps = sum_xx_f / (float) D + eps;
602624
const float sum_eps = sum_xx_f + eps * (float) D;
603-
l_inv_r[0] = sycl::rsqrt(mean_eps);
604-
l_coeff[0] = -sum_xdz_f / sum_eps;
625+
inv_r = sycl::rsqrt(mean_eps);
626+
coeff = -sum_xdz_f / sum_eps;
605627
}
628+
inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r);
629+
coeff = sycl::group_broadcast(item_ct1.get_group(), coeff);
606630

607-
item_ct1.barrier(sycl::access::fence_space::local_space);
608-
const float inv_r = l_inv_r[0];
609-
const float coeff = l_coeff[0];
610631
for (int64_t col = tid; col < D; col += WG) {
611632
d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r;
612633
}

0 commit comments

Comments (0)