@@ -496,45 +496,40 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d
     float * dx_base = static_cast<float *>(dst->data);
 
     const int64_t D = dst->ne[0];
-    const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3];
-    (void) n3;
+    const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3]; (void) n3;
     const int64_t N = ggml_nrows(dst);
     if (D == 0 || N == 0) return;
 
     const ggml_tensor *G = dst->src[0];
     const ggml_tensor *X = dst->src[1];
     const int ts = (int) ggml_type_size(X->type);
-    GGML_ASSERT((size_t) X->nb[0] == (size_t) ts);
-    GGML_ASSERT((size_t) G->nb[0] == (size_t) ts);
+    GGML_ASSERT((size_t) X->nb[0] == (size_t) ts);
+    GGML_ASSERT((size_t) G->nb[0] == (size_t) ts);
     GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts);
 
-    const int64_t xs1 = X->nb[1] / ts;
-    const int64_t xs2 = X->nb[2] / ts;
-    const int64_t xs3 = X->nb[3] / ts;
-    const int64_t gs1 = G->nb[1] / ts;
-    const int64_t gs2 = G->nb[2] / ts;
-    const int64_t gs3 = G->nb[3] / ts;
-    const int64_t ds1 = dst->nb[1] / ts;
-    const int64_t ds2 = dst->nb[2] / ts;
-    const int64_t ds3 = dst->nb[3] / ts;
+    const int64_t xs1 = X->nb[1] / ts, xs2 = X->nb[2] / ts, xs3 = X->nb[3] / ts;
+    const int64_t gs1 = G->nb[1] / ts, gs2 = G->nb[2] / ts, gs3 = G->nb[3] / ts;
+    const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts;
 
     dpct::queue_ptr q = ctx.stream();
 
-    // work-group size: multiple of WARP_SIZE, capped by device and 256
+    // work-group size: multiple of WARP_SIZE, capped by device and 256, and not larger than D
     const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device];
     auto roundup = [](int v, int m) { return ((v + m - 1) / m) * m; };
     int wg_cap = 256;
     if (device_max_wg > 0) wg_cap = std::min(wg_cap, device_max_wg);
     int WG = std::max(WARP_SIZE, std::min(roundup((int) std::min<int64_t>(D, wg_cap), WARP_SIZE), wg_cap));
 
-    // FP32 compensated reduction
+    // FP32 path: per-thread compensated accumulation + hierarchical reduction
     q->submit([&](sycl::handler &cgh) {
-        auto l_inv_r = sycl::local_accessor<float, 1>(sycl::range<1>(1), cgh);
-        auto l_coeff = sycl::local_accessor<float, 1>(sycl::range<1>(1), cgh);
-        auto l_part = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(std::max(1, WG / WARP_SIZE)), cgh);
+        const int nwarps_loc = std::max(1, WG / WARP_SIZE);
+        // store one partial value per warp (xx and xg) for cross-warp reduction
+        auto l_xx = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
+        auto l_xg = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
 
         cgh.parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG), sycl::range<3>(1, 1, WG)),
+            sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG),
+                              sycl::range<3>(1, 1, WG)),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 const int row = item_ct1.get_group(2);
                 const int tid = item_ct1.get_local_id(2);
@@ -547,29 +542,49 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d
                 const float *__restrict g_row = g_base + i3 * gs3 + i2 * gs2 + i1 * gs1;
                 float *__restrict d_row = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1;
 
-                // per-thread compensated sums for sum(x^2) and sum(x*dz)
-                float sum_xx = 0.f, c_xx = 0.f;
-                float sum_xg = 0.f, c_xg = 0.f;
+                // per-thread accumulation (compensated by default)
+                float sum_xx = 0.f, sum_xg = 0.f;
+#ifndef GGML_SYCL_RMS_BACK_FAST
+                float c_xx = 0.f, c_xg = 0.f;
+#endif
                 for (int64_t col = tid; col < D; col += WG) {
                     const float xv = x_row[col];
                     const float gv = g_row[col];
-                    float y1 = xv * xv - c_xx; // compensated add for x^2
-                    float t1 = sum_xx + y1;
+#ifdef GGML_SYCL_RMS_BACK_FAST
+                    sum_xx += xv * xv;
+                    sum_xg += xv * gv;
+#else
+                    float y1 = xv * xv - c_xx;
+                    float t1 = sum_xx + y1;
                     c_xx = (t1 - sum_xx) - y1;
                     sum_xx = t1;
-                    float y2 = xv * gv - c_xg; // compensated add for x*dz
-                    float t2 = sum_xg + y2;
+
+                    float y2 = xv * gv - c_xg;
+                    float t2 = sum_xg + y2;
                     c_xg = (t2 - sum_xg) - y2;
                     sum_xg = t2;
+#endif
                 }
 
-                // reduce within warp
-                sycl::float2 xx = sycl::float2(sum_xx, c_xx);
-                sycl::float2 xg = sycl::float2(sum_xg, c_xg);
+                // warp-level reduction
+                sycl::float2 xx = sycl::float2(sum_xx,
+#ifndef GGML_SYCL_RMS_BACK_FAST
+                                               c_xx
+#else
+                                               0.f
+#endif
+                );
+                sycl::float2 xg = sycl::float2(sum_xg,
+#ifndef GGML_SYCL_RMS_BACK_FAST
+                                               c_xg
+#else
+                                               0.f
+#endif
+                );
                 xx = warp_reduce_sum(xx, item_ct1);
                 xg = warp_reduce_sum(xg, item_ct1);
 
-                // cross-warp reduction if needed
+                // cross-warp reduction using local memory (single barrier)
                 const auto sub_group = item_ct1.get_sub_group();
                 const auto sg_id = sub_group.get_group_linear_id();
                 const auto wi_in_sg = sub_group.get_local_linear_id();
@@ -579,34 +594,40 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d
                 sycl::float2 xx_total = xx;
                 sycl::float2 xg_total = xg;
                 if (nwarps > 1) {
-                    if (wi_in_sg == 0) l_part[sg_id] = xx;
+                    if (wi_in_sg == 0) {
+                        l_xx[sg_id] = xx;
+                        l_xg[sg_id] = xg;
+                    }
                     item_ct1.barrier(sycl::access::fence_space::local_space);
 
-                    xx_total = sycl::float2(0.f, 0.f);
-                    const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
-                    for (size_t i = 0; i < nreduce; ++i) xx_total += l_part[wi_in_sg + i * WARP_SIZE];
-                    xx_total = warp_reduce_sum(xx_total, item_ct1);
-
-                    if (wi_in_sg == 0) l_part[sg_id] = xg;
-                    item_ct1.barrier(sycl::access::fence_space::local_space);
-                    xg_total = sycl::float2(0.f, 0.f);
-                    for (size_t i = 0; i < nreduce; ++i) xg_total += l_part[wi_in_sg + i * WARP_SIZE];
-                    xg_total = warp_reduce_sum(xg_total, item_ct1);
+                    if (sg_id == 0) {
+                        const unsigned wi_u = wi_in_sg;
+                        sycl::float2 xx_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f);
+                        sycl::float2 xg_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xg[wi_u] : sycl::float2(0.f, 0.f);
+                        xx_total = warp_reduce_sum(xx_first, item_ct1);
+                        xg_total = warp_reduce_sum(xg_first, item_ct1);
+                    } else {
+                        // other subgroups keep their local totals; they'll be ignored
+                        xx_total = xx;
+                        xg_total = xg;
+                    }
+                    // ensure all threads see the first-subgroup result via broadcast below
                 }
 
-                // compute inv_r and coeff once per row
+                // compute inv_r and coeff once per row and broadcast to the whole work-group
+                float inv_r = 0.f;
+                float coeff = 0.f;
                 if (tid == 0) {
                     const float sum_xx_f = xx_total.x() + xx_total.y();
                     const float sum_xdz_f = xg_total.x() + xg_total.y();
                     const float mean_eps = sum_xx_f / (float) D + eps;
                     const float sum_eps = sum_xx_f + eps * (float) D;
-                    l_inv_r[0] = sycl::rsqrt(mean_eps);
-                    l_coeff[0] = -sum_xdz_f / sum_eps;
+                    inv_r = sycl::rsqrt(mean_eps);
+                    coeff = -sum_xdz_f / sum_eps;
                 }
+                inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r);
+                coeff = sycl::group_broadcast(item_ct1.get_group(), coeff);
 
-                item_ct1.barrier(sycl::access::fence_space::local_space);
-                const float inv_r = l_inv_r[0];
-                const float coeff = l_coeff[0];
                 for (int64_t col = tid; col < D; col += WG) {
                     d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r;
                 }
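
For reference, the two per-row scalars produced in the tid == 0 block match the RMS-norm backward formula; the short derivation below is editorial commentary, not part of the patch. With $r = \left(\tfrac{1}{D}\sum_i x_i^2 + \varepsilon\right)^{-1/2}$ the forward op is $y_i = r\,x_i$, and since $\partial r / \partial x_j = -x_j r^3 / D$, the chain rule over the incoming gradient $g$ gives

$$
\frac{\partial L}{\partial x_j} = r\,g_j - x_j \frac{r^3}{D} \sum_i x_i g_i
  = r \left( g_j - x_j \, \frac{\sum_i x_i g_i}{\sum_i x_i^2 + \varepsilon D} \right),
$$

using $r^2 / D = 1 / \left(\sum_i x_i^2 + \varepsilon D\right)$. The kernel stores $r$ as inv_r (from mean_eps) and the negated fraction as coeff (from sum_eps), so the final loop d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r evaluates exactly this expression.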
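
The y1/t1/c_xx pattern in the accumulation loop is Kahan (compensated) summation: the carry c_xx captures the low-order bits each float add would otherwise round away, which is also why the warp reduction carries (sum, carry) pairs as sycl::float2. A minimal host-side sketch of the same pattern, in plain C++ with made-up values (not part of the patch):

#include <cstdio>

int main() {
    // 1.0 followed by many tiny addends: naive float summation drops them all,
    // because 1e-8 is below the ulp of float at 1.0 (~1.19e-7)
    const int n = 10000000;
    float naive = 1.0f, sum = 1.0f, c = 0.0f;
    for (int i = 0; i < n; ++i) {
        const float v = 1e-8f;
        naive += v;          // lost: rounds back to the old value
        float y = v - c;     // compensated add, same shape as the y1/c_xx lines
        float t = sum + y;
        c = (t - sum) - y;   // recover the rounding error of this add
        sum = t;
    }
    // expected: naive stays 1.0, kahan lands near 1.1
    std::printf("naive: %.7f  kahan: %.7f  expected: %.7f\n", naive, sum, 1.1f);
}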
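
sycl::group_broadcast is a SYCL 2020 group function that returns the group leader's value (lowest local id by default) to every work-item, which is what lets the patch drop the l_inv_r/l_coeff local accessors and the trailing barrier. A standalone sketch of the primitive, assuming a SYCL 2020 compiler such as icpx; the queue setup and values are illustrative, not the PR's code:

#include <sycl/sycl.hpp>
#include <iostream>

int main() {
    sycl::queue q;
    constexpr int WG = 32;
    float *out = sycl::malloc_shared<float>(WG, q);

    q.parallel_for(sycl::nd_range<1>(sycl::range<1>(WG), sycl::range<1>(WG)),
                   [=](sycl::nd_item<1> it) {
        float v = 0.f;
        if (it.get_local_id(0) == 0) v = 42.f;        // only the leader knows the value
        v = sycl::group_broadcast(it.get_group(), v); // now every work-item does
        out[it.get_global_id(0)] = v;
    }).wait();

    std::cout << out[0] << " " << out[WG - 1] << "\n"; // both print 42
    sycl::free(out, q);
}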