
Commit ebac831

Merge branch 'master' into xsn/fix-mrope-causal
2 parents c3e1393 + f549b00 commit ebac831

17 files changed: +269, -69 lines changed

docs/ops.md

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ Legend:
 | REPEAT |||| 🟡 || 🟡 || 🟡 ||
 | REPEAT_BACK ||||||||||
 | RMS_NORM ||||| 🟡 |||||
-| RMS_NORM_BACK ||||||| |||
+| RMS_NORM_BACK ||||||| |||
 | RMS_NORM_MUL_ADD ||||||||||
 | ROLL ||||||||||
 | ROPE || 🟡 ||||||||

docs/ops/SYCL.csv

Lines changed: 4 additions & 4 deletions
@@ -5637,25 +5637,25 @@
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","SYCL"
-"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","0","no","SYCL"
+"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","1","yes","SYCL"
 "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","SYCL"
-"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","0","no","SYCL"
+"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","1","yes","SYCL"
 "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","SYCL"
-"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","0","no","SYCL"
+"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","1","yes","SYCL"
 "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","SYCL"
-"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","0","no","SYCL"
+"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","1","yes","SYCL"
 "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","SYCL"
 "SYCL0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=0.000000,broadcast=0,multi_add=0","support","1","yes","SYCL"

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 24 additions & 2 deletions
@@ -50,6 +50,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
 #include "ggml.h"
@@ -2416,6 +2417,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SET_ROWS:
             ggml_cuda_op_set_rows(ctx, dst);
             break;
+        case GGML_OP_SET:
+            ggml_cuda_op_set(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cuda_dup(ctx, dst);
             break;
@@ -2974,7 +2978,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);

     if (ops.size() == topk_moe_ops_with_norm.size() &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
+        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 9];

@@ -2993,7 +2997,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     }

     if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2, node_idx + 5 })) {
+        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
         ggml_tensor * weights = cgraph->nodes[node_idx + 5];

@@ -3114,9 +3118,20 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     // With the use of CUDA graphs, the execution will be performed by the graph launch.
     if (!use_cuda_graph || cuda_graph_update_required) {

+        [[maybe_unused]] int prev_i = 0;
+
         for (int i = 0; i < cgraph->n_nodes; i++) {
             ggml_tensor * node = cgraph->nodes[i];

+
+#ifdef GGML_CUDA_DEBUG
+            const int nodes_fused = i - prev_i - 1;
+            prev_i = i;
+            if (nodes_fused > 0) {
+                GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
+            }
+#endif
+
             if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
                 continue;
             }
@@ -3842,6 +3857,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 op->src[0]->type == GGML_TYPE_F32 &&
                 (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
             } break;
+        case GGML_OP_SET:
+            {
+                const ggml_type t = op->type;
+                return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
+                       t == op->src[0]->type &&
+                       t == op->src[1]->type;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
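
For context, GGML_OP_SET nodes like the ones routed above are normally built with ggml_set()/ggml_set_inplace() from ggml.h, with the view's byte strides and byte offset stored in dst->op_params. A minimal sketch, assuming only the public ggml API (ggml_new_tensor_2d, ggml_set) and illustrative shapes:

// Hedged sketch: overwrite a 4x2 block of `a` with `b`, starting at row 1 of `a`.
// Strides (nb1/nb2/nb3) and the offset are in bytes, as GGML_OP_SET expects.
#include "ggml.h"

struct ggml_tensor * set_example(struct ggml_context * ctx) {
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4); // dst / src0
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2); // src1

    // place the view at the start of row 1 of `a`, reusing `a`'s own strides
    return ggml_set(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], /*offset=*/1 * a->nb[1]);
}

On a device that now reports support through the GGML_OP_SET case above, such a node is executed by ggml_cuda_op_set() from the new set.cu below.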

ggml/src/ggml-cuda/set.cu

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+#include "set.cuh"
+#include "cpy.cuh"
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
+    GGML_ASSERT(src1->type == src0->type);
+    GGML_ASSERT(dst ->type == src0->type);
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    const size_t nb1     = ((int32_t *) dst->op_params)[0];
+    const size_t nb2     = ((int32_t *) dst->op_params)[1];
+    const size_t nb3     = ((int32_t *) dst->op_params)[2];
+    const size_t offset  = ((int32_t *) dst->op_params)[3];
+    const bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        ggml_cuda_cpy(ctx, src0, dst);
+    }
+
+    ggml_tensor dst_view = *dst;
+    dst_view.data  = (void *)((char *)dst->data + offset);
+    dst_view.ne[0] = src1->ne[0];
+    dst_view.ne[1] = src1->ne[1];
+    dst_view.ne[2] = src1->ne[2];
+    dst_view.ne[3] = src1->ne[3];
+
+    dst_view.nb[0] = ggml_element_size(dst);
+    dst_view.nb[1] = nb1;
+    dst_view.nb[2] = nb2;
+    dst_view.nb[3] = nb3;
+
+    ggml_cuda_cpy(ctx, src1, &dst_view);
+}
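
The kernel above realizes SET as two copies: when not in place, src0 is first copied wholesale into dst, and then src1 is copied into a view of dst described by the byte offset and byte strides from op_params. A host-side reference of the same semantics for f32 data, as a minimal sketch (the helper name set_f32_ref is illustrative, not from the PR):

#include <stdint.h>
#include <string.h>

// Reference semantics of GGML_OP_SET for contiguous f32 tensors:
// dst starts as a copy of src0 (unless in place), then the region at byte
// `offset` with byte strides nb1/nb2/nb3 is overwritten row by row with the
// contiguous src1 of extents sne[0..3].
static void set_f32_ref(float * dst, const float * src0, const float * src1,
                        const int64_t ne[4], const int64_t sne[4],
                        size_t nb1, size_t nb2, size_t nb3, size_t offset,
                        int inplace) {
    if (!inplace) {
        memcpy(dst, src0, sizeof(float) * ne[0] * ne[1] * ne[2] * ne[3]);
    }
    char * base = (char *) dst + offset;
    for (int64_t i3 = 0; i3 < sne[3]; ++i3) {
        for (int64_t i2 = 0; i2 < sne[2]; ++i2) {
            for (int64_t i1 = 0; i1 < sne[1]; ++i1) {
                float       * drow = (float *)(base + i1*nb1 + i2*nb2 + i3*nb3);
                const float * srow = src1 + ((i3*sne[2] + i2)*sne[1] + i1)*sne[0];
                memcpy(drow, srow, sizeof(float) * sne[0]);
            }
        }
    }
}

ggml_cuda_op_set() achieves the same effect by reusing ggml_cuda_cpy() twice rather than launching a dedicated kernel.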

ggml/src/ggml-cuda/set.cuh

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "common.cuh"
+
+#define CUDA_SET_BLOCK_SIZE 256
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 6 additions & 3 deletions
@@ -211,7 +211,7 @@ static inline void hex_format_op_names(char * str, const struct ggml_tensor * t)
 // ** backend sessions

 struct ggml_hexagon_session {
-    ggml_hexagon_session(int dev_id) noexcept(false);
+    ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false);
     ~ggml_hexagon_session() noexcept(true);

     void allocate(int dev_id) noexcept(false);
@@ -1631,10 +1631,13 @@ void ggml_hexagon_session::release() noexcept(true) {
     }
 }

-ggml_hexagon_session::ggml_hexagon_session(int dev_id) noexcept(false) {
+ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) {
     buffer_type.context = nullptr;
     repack_buffer_type.context = nullptr;

+    buffer_type.device = dev;
+    repack_buffer_type.device = dev;
+
     try {
         allocate(dev_id);

@@ -3628,7 +3631,7 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
         devices[i].iface = ggml_backend_hexagon_device_i;
         devices[i].reg = reg;
         try {
-            devices[i].context = new ggml_hexagon_session(i);
+            devices[i].context = new ggml_hexagon_session(i, &devices[i]);
         } catch (std::exception const &exc) {
             GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i);
             devices[i].context = nullptr;

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 11 additions & 0 deletions
@@ -42,6 +42,7 @@
 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/common.hpp"
 #include "ggml-sycl/element_wise.hpp"
+#include "ggml-sycl/norm.hpp"
 #include "ggml-sycl/presets.hpp"
 #include "ggml-sycl/gemm.hpp"
 #include "ggml-sycl/set_rows.hpp"
@@ -2637,6 +2638,11 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
     ggml_sycl_op_rms_norm(ctx, dst);
 }

+static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+    ggml_sycl_op_rms_norm_back(ctx, dst);
+}
+
 static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_l2_norm(ctx, dst);
@@ -3827,6 +3833,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_LEAKY_RELU:
             ggml_sycl_leaky_relu(ctx, dst);
             break;
+        case GGML_OP_RMS_NORM_BACK:
+            ggml_sycl_rms_norm_back(ctx, dst);
+            break;
         case GGML_OP_RMS_NORM:
             ggml_sycl_rms_norm(ctx, dst);
             break;
@@ -4571,6 +4580,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_RMS_NORM:
             return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
+        case GGML_OP_RMS_NORM_BACK:
+            return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
         case GGML_OP_SCALE:
             return true;
         case GGML_OP_CONT:

ggml/src/ggml-sycl/norm.cpp

Lines changed: 156 additions & 0 deletions
@@ -480,6 +480,162 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
 }

+void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); // dz
+    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); // x
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    float eps = 1e-5f;
+    std::memcpy(&eps, dst->op_params, sizeof(float));
+    if (!(eps > 0.0f) || !std::isfinite(eps)) eps = 1e-5f;
+
+    const float * g_base  = static_cast<const float *>(dst->src[0]->data); // dz
+    const float * x_base  = static_cast<const float *>(dst->src[1]->data); // x
+    float       * dx_base = static_cast<      float *>(dst->data);
+
+    const int64_t D  = dst->ne[0];
+    const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3]; (void) n3;
+    const int64_t N  = ggml_nrows(dst);
+    if (D == 0 || N == 0) return;
+
+    const ggml_tensor *G = dst->src[0];
+    const ggml_tensor *X = dst->src[1];
+    const int ts = (int) ggml_type_size(X->type);
+    GGML_ASSERT((size_t) X->nb[0]   == (size_t) ts);
+    GGML_ASSERT((size_t) G->nb[0]   == (size_t) ts);
+    GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts);
+
+    const int64_t xs1 = X->nb[1] / ts,   xs2 = X->nb[2] / ts,   xs3 = X->nb[3] / ts;
+    const int64_t gs1 = G->nb[1] / ts,   gs2 = G->nb[2] / ts,   gs3 = G->nb[3] / ts;
+    const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts;
+
+    dpct::queue_ptr q = ctx.stream();
+
+    // work-group size: multiple of WARP_SIZE, capped by device and 256, and not larger than D
+    const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device];
+    auto roundup = [](int v, int m) { return ((v + m - 1) / m) * m; };
+    int wg_cap = 256;
+    if (device_max_wg > 0) wg_cap = std::min(wg_cap, device_max_wg);
+    int WG = std::max(WARP_SIZE, std::min(roundup((int)std::min<int64_t>(D, wg_cap), WARP_SIZE), wg_cap));
+
+    // FP32 path: per-thread compensated accumulation + hierarchical reduction
+    q->submit([&](sycl::handler &cgh) {
+        const int nwarps_loc = std::max(1, WG / WARP_SIZE);
+        // store one partial value per warp (xx and xg) for cross-warp reduction
+        auto l_xx = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
+        auto l_xg = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG),
+                              sycl::range<3>(1, 1, WG)),
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                const int row = item_ct1.get_group(2);
+                const int tid = item_ct1.get_local_id(2);
+
+                const int64_t i1 = row % n1;
+                const int64_t i2 = (row / n1) % n2;
+                const int64_t i3 = row / (n1 * n2);
+
+                const float *__restrict x_row = x_base + i3 * xs3 + i2 * xs2 + i1 * xs1;
+                const float *__restrict g_row = g_base + i3 * gs3 + i2 * gs2 + i1 * gs1;
+                float       *__restrict d_row = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1;
+
+                // per-thread accumulation (compensated by default)
+                float sum_xx = 0.f, sum_xg = 0.f;
+#ifndef GGML_SYCL_RMS_BACK_FAST
+                float c_xx = 0.f, c_xg = 0.f;
+#endif
+                for (int64_t col = tid; col < D; col += WG) {
+                    const float xv = x_row[col];
+                    const float gv = g_row[col];
+#ifdef GGML_SYCL_RMS_BACK_FAST
+                    sum_xx += xv * xv;
+                    sum_xg += xv * gv;
+#else
+                    float y1 = xv * xv - c_xx;
+                    float t1 = sum_xx + y1;
+                    c_xx = (t1 - sum_xx) - y1;
+                    sum_xx = t1;
+
+                    float y2 = xv * gv - c_xg;
+                    float t2 = sum_xg + y2;
+                    c_xg = (t2 - sum_xg) - y2;
+                    sum_xg = t2;
+#endif
+                }
+
+                // warp-level reduction
+                sycl::float2 xx = sycl::float2(sum_xx,
+#ifndef GGML_SYCL_RMS_BACK_FAST
+                    c_xx
+#else
+                    0.f
+#endif
+                );
+                sycl::float2 xg = sycl::float2(sum_xg,
+#ifndef GGML_SYCL_RMS_BACK_FAST
+                    c_xg
+#else
+                    0.f
+#endif
+                );
+                xx = warp_reduce_sum(xx, item_ct1);
+                xg = warp_reduce_sum(xg, item_ct1);
+
+                // cross-warp reduction using local memory (single barrier)
+                const auto sub_group = item_ct1.get_sub_group();
+                const auto sg_id     = sub_group.get_group_linear_id();
+                const auto wi_in_sg  = sub_group.get_local_linear_id();
+                const int  nthreads  = item_ct1.get_local_range(2);
+                const int  nwarps    = nthreads / WARP_SIZE;
+
+                sycl::float2 xx_total = xx;
+                sycl::float2 xg_total = xg;
+                if (nwarps > 1) {
+                    if (wi_in_sg == 0) {
+                        l_xx[sg_id] = xx;
+                        l_xg[sg_id] = xg;
+                    }
+                    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+                    if (sg_id == 0) {
+                        const unsigned wi_u = wi_in_sg;
+                        sycl::float2 xx_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f);
+                        sycl::float2 xg_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xg[wi_u] : sycl::float2(0.f, 0.f);
+                        xx_total = warp_reduce_sum(xx_first, item_ct1);
+                        xg_total = warp_reduce_sum(xg_first, item_ct1);
+                    } else {
+                        // other subgroups keep their local totals; they'll be ignored
+                        xx_total = xx;
+                        xg_total = xg;
+                    }
+                    // ensure all threads see the first-subgroup result via broadcast below
+                }
+
+                // compute inv_r and coeff once per row and broadcast to the whole work-group
+                float inv_r = 0.f;
+                float coeff = 0.f;
+                if (tid == 0) {
+                    const float sum_xx_f  = xx_total.x() + xx_total.y();
+                    const float sum_xdz_f = xg_total.x() + xg_total.y();
+                    const float mean_eps  = sum_xx_f / (float) D + eps;
+                    const float sum_eps   = sum_xx_f + eps * (float) D;
+                    inv_r = sycl::rsqrt(mean_eps);
+                    coeff = -sum_xdz_f / sum_eps;
+                }
+                inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r);
+                coeff = sycl::group_broadcast(item_ct1.get_group(), coeff);
+
+                for (int64_t col = tid; col < D; col += WG) {
+                    d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r;
+                }
+            });
+    });
+
+}
+
 void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
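
For readers of the kernel above, the per-row quantities it computes (sum_xx_f, sum_xdz_f, inv_r, coeff) correspond to the standard RMS-norm backward formula; written out as a reading aid, with x the forward input row, dz the incoming gradient row, and D = ne[0]:

\[
  s_{xx} = \sum_{i=1}^{D} x_i^2, \qquad
  s_{x\,dz} = \sum_{i=1}^{D} x_i \, dz_i
\]
\[
  \mathrm{inv\_r} = \frac{1}{\sqrt{s_{xx}/D + \varepsilon}}, \qquad
  \mathrm{coeff} = -\frac{s_{x\,dz}}{s_{xx} + \varepsilon D}, \qquad
  dx_i = \bigl(dz_i + \mathrm{coeff}\cdot x_i\bigr)\cdot \mathrm{inv\_r}
\]

The compensated sums (c_xx, c_xg) only improve the accuracy of s_xx and s_xdz; defining GGML_SYCL_RMS_BACK_FAST switches to plain accumulation.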
