
Commit 86076f9

OpenCL: add fused group_norm/norm, mul, add (ggml-org#15314)
* add fused group_norm/norm, mul, add
* fix spacing
* revert rms_norm logic
* fix trailing whitespace
1 parent bcbddcd commit 86076f9

File tree: 4 files changed, +399 −4 lines


ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 185 additions & 4 deletions
@@ -420,9 +420,9 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_clamp;
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
               kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
-    cl_kernel kernel_norm;
+    cl_kernel kernel_norm, kernel_norm_mul_add;
     cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
-    cl_kernel kernel_group_norm;
+    cl_kernel kernel_group_norm, kernel_group_norm_mul_add;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
@@ -1161,7 +1161,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_norm         = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err));
         GGML_LOG_CONT(".");
     }

@@ -1487,7 +1488,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_group_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_group_norm         = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err));
         GGML_LOG_CONT(".");
     }

@@ -2498,12 +2500,47 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
         if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
             return false;
         }
+    } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
+        const ggml_tensor *norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+        const ggml_tensor *add = cgraph->nodes[node_idx+2];
+        const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0];
+        const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
+
+        // norm fusion only supports F32
+        if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        if (norm->src[0]->ne[0] % 4 != 0) {
+            return false;
+        }
+
+        if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
+            return false;
+        }
+    } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
+        const ggml_tensor *gn = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+        const ggml_tensor *add = cgraph->nodes[node_idx+2];
+        const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0];
+        const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
+
+        if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
+            return false;
+        }
     }
 
     return true;
 }
 
 static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
+static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
 
 static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
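Note how the eligibility check picks out the weight and bias operands regardless of argument order: the norm result may be either source of the mul, and the mul result either source of the add. The extra ne[0] % 4 == 0 requirement on the norm path exists because kernel_norm_mul_add (see norm.cl below) walks each row in float4 chunks.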
@@ -2520,6 +2557,16 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }
 
+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
+            ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+            i += 2;
+            continue;
+        }
+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
+            ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+            i += 2;
+            continue;
+        }
         if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
             ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
             i++;
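The two three-op fusion checks run before the existing two-op RMS_NORM + MUL check, so the longer pattern wins when both would match, and i += 2 skips over the mul and add nodes whose outputs the fused kernel has already produced.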
@@ -5039,6 +5086,140 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
+static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
+    GGML_ASSERT(norm_tensor && mul_tensor && add_tensor);
+
+    const ggml_tensor * src0 = norm_tensor->src[0];
+    const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
+    const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
+    const ggml_tensor * dst = add_tensor;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offset2 = extra2->offset + src2->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    float eps;
+    memcpy(&eps, norm_tensor->op_params, sizeof(float));
+
+    const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
+    const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
+    const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3];
+    const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
+    const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3];
+    const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3];
+    const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3];
+
+    size_t sgs;
+    if (backend_ctx->gpu_family == ADRENO) sgs = 64;
+    else if (backend_ctx->gpu_family == INTEL) sgs = 32;
+    else GGML_ASSERT(false && "Unsupported GPU");
+
+    cl_kernel kernel = backend_ctx->kernel_norm_mul_add;
+
+    int nth = sgs;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2;
+    nth = MIN(nth, max_workgroup_size);
+    nth = MIN(nth, ne00/4);
+
+    size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t lws[] = {(size_t)nth, 1, 1};
+    size_t num_subgroups = (nth + sgs - 1) / sgs;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne20));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne21));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne22));
+    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne23));
+    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21));
+    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22));
+    CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23));
+    CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1));
+    CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2));
+    CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3));
+    CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &eps));
+    CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst);
+}
+
+static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
+    GGML_ASSERT(gn_tensor && mul_tensor && add_tensor);
+
+    const ggml_tensor * src0 = gn_tensor->src[0];
+    const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
+    const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
+    const ggml_tensor * dst = add_tensor;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offset2 = extra2->offset + src2->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    int groups;
+    float eps;
+    memcpy(&groups, gn_tensor->op_params, sizeof(int));
+    memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float));
+
+    cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    int ne = ggml_nelements(src0);
+    int group_size = ne / groups;
+
+    size_t lws[] = { (size_t)MIN(max_workgroup_size, group_size) };
+    size_t gws[] = { (size_t)groups * lws[0] };
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &group_size));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &eps));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst);
+}
+
 static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
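To make the dispatch sizing above concrete, here is a worked example under assumed device limits (hypothetical numbers; sgs and max_workgroup_size are device-dependent):

    // Assume ne00 = 4096 on an Intel GPU (sgs = 32) and a reported kernel
    // max workgroup size of 256:
    //   nth: starts at 32, doubles while nth < ne00/4 (= 1024) and
    //        nth < 256                      -> nth = 256
    //   num_subgroups = (256 + 32 - 1) / 32 = 8
    //   local scratch (arg 33): 8 * sizeof(cl_float2) = 64 bytes
    // gws = {ne01*256, ne02, ne03}, lws = {256, 1, 1}: one workgroup per
    // row, each thread striding across the row's 1024 float4 chunks.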

ggml/src/ggml-opencl/kernels/group_norm.cl

Lines changed: 49 additions & 0 deletions
@@ -70,3 +70,52 @@ kernel void kernel_group_norm(
         dst[j] *= scale;
     }
 }
+
+//------------------------------------------------------------------------------
+// group_norm_mul_add
+//------------------------------------------------------------------------------
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_group_norm_mul_add(
+        global float * src0, ulong offset0,
+        global float * src1, ulong offset1,
+        global float * src2, ulong offset2,
+        global float * dst, ulong offsetd,
+        int ne,
+        int group_size,
+        float eps
+) {
+    src0 = (global float *)((global char *)src0 + offset0);
+    src1 = (global float *)((global char *)src1 + offset1);
+    src2 = (global float *)((global char *)src2 + offset2);
+    dst = (global float *)((global char *)dst + offsetd);
+
+    int start = get_group_id(0) * group_size;
+    int end = start + group_size;
+    if (end > ne) {
+        end = ne;
+    }
+
+    float sum = 0.0f;
+    float sum_sq = 0.0f;
+
+    for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {
+        float val = src0[j];
+        sum += val;
+        sum_sq += val*val;
+    }
+
+    sum = sub_group_reduce_add(sum);
+    sum_sq = sub_group_reduce_add(sum_sq);
+
+    const float mean = sum / group_size;
+    const float var = sum_sq / group_size - mean * mean;
+    const float scale = rsqrt(var + eps);
+
+    for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {
+        dst[j] = ((src0[j] - mean) * scale) * src1[j] + src2[j];
+    }
+}
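The arithmetic this kernel performs is dst = (src0 - mean) / sqrt(var + eps) * src1 + src2, with mean and var computed over each contiguous group. A scalar host-side reference that could serve as a validation baseline (a hypothetical helper, not backend code; like the kernel, it indexes src1 and src2 at the same flat offset as src0):

    #include <cmath>

    // Scalar reference for the fused group_norm + mul + add.
    // n: total element count; n_groups: number of contiguous groups.
    static void group_norm_mul_add_ref(const float * src0, const float * src1,
                                       const float * src2, float * dst,
                                       int n, int n_groups, float eps) {
        const int group_size = n / n_groups;
        for (int g = 0; g < n_groups; ++g) {
            const int start = g * group_size;
            const int end   = start + group_size > n ? n : start + group_size;
            float sum = 0.0f, sum_sq = 0.0f;
            for (int j = start; j < end; ++j) {
                sum    += src0[j];
                sum_sq += src0[j] * src0[j];
            }
            const float mean  = sum / group_size;
            const float var   = sum_sq / group_size - mean * mean;
            const float scale = 1.0f / std::sqrt(var + eps);
            for (int j = start; j < end; ++j) {
                dst[j] = (src0[j] - mean) * scale * src1[j] + src2[j];
            }
        }
    }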

ggml/src/ggml-opencl/kernels/norm.cl

Lines changed: 80 additions & 0 deletions
@@ -79,3 +79,83 @@ kernel void kernel_norm(
         y[i00] = y[i00] * scale;
     }
 }
+
+//------------------------------------------------------------------------------
+// norm_mul_add
+//------------------------------------------------------------------------------
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_norm_mul_add(
+        global char * src0_ptr, ulong src0_offset,
+        global char * src1_ptr, ulong src1_offset,
+        global char * src2_ptr, ulong src2_offset,
+        global char * dst_ptr, ulong dst_offset,
+        int ne00, int ne01, int ne02, int ne03,
+        ulong nb01, ulong nb02, ulong nb03,
+        int ne10, int ne11, int ne12, int ne13,
+        ulong nb11, ulong nb12, ulong nb13,
+        int ne20, int ne21, int ne22, int ne23,
+        ulong nb21, ulong nb22, ulong nb23,
+        ulong nbd1, ulong nbd2, ulong nbd3,
+        float eps,
+        local float2 * sums
+) {
+    const int i03 = get_group_id(2);
+    const int i02 = get_group_id(1);
+    const int i01 = get_group_id(0);
+
+    global float4 * x = (global float4 *)(src0_ptr + src0_offset + i01*nb01 + i02*nb02 + i03*nb03);
+    global float4 * w = (global float4 *)(src1_ptr + src1_offset + (i01%ne11)*nb11 + (i02%ne12)*nb12 + (i03%ne13)*nb13);
+    global float4 * b = (global float4 *)(src2_ptr + src2_offset + (i01%ne21)*nb21 + (i02%ne22)*nb22 + (i03%ne23)*nb23);
+    global float4 * y = (global float4 *)(dst_ptr + dst_offset + i01*nbd1 + i02*nbd2 + i03*nbd3);
+
+    float p_sum = 0.0f;
+    float p_sum_sq = 0.0f;
+
+    const int n_chunks = ne00 / 4;
+    for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
+        float4 val = x[i00];
+        p_sum += val.x + val.y + val.z + val.w;
+        p_sum_sq += dot(val, val);
+    }
+
+    p_sum = sub_group_reduce_add(p_sum);
+    p_sum_sq = sub_group_reduce_add(p_sum_sq);
+
+    if (get_sub_group_local_id() == 0) {
+        sums[get_sub_group_id()] = (float2)(p_sum, p_sum_sq);
+    }
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if (get_local_id(0) == 0) {
+        float sum = 0.0f;
+        float sum_sq = 0.0f;
+        for (uint i = 0; i < get_num_sub_groups(); ++i) {
+            float2 s = sums[i];
+            sum += s.x;
+            sum_sq += s.y;
+        }
+
+        const float inv_ne00 = 1.0f / (float)ne00;
+        const float mean = sum * inv_ne00;
+        const float variance = mad(-mean, mean, sum_sq * inv_ne00);
+
+        sums[0] = (float2)(mean, rsqrt(variance + eps));
+    }
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    const float2 mean_scale = sums[0];
+    const float mean = mean_scale.x;
+    const float scale = mean_scale.y;
+    const float neg_mean_scale = -mean * scale;
+
+    for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
+        const int w_idx = ne10 > 1 ? i00 : 0;
+        const int b_idx = ne20 > 1 ? i00 : 0;
+        const float4 norm_x = mad(x[i00], (float4)scale, (float4)neg_mean_scale);
+        y[i00] = mad(norm_x, w[w_idx], b[b_idx]);
+    }
+}
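One detail in the final loop: rather than computing (x - mean) * scale with a subtract plus a multiply per chunk, the kernel precomputes neg_mean_scale = -mean * scale and uses the identity (x - mean) * scale = x * scale + neg_mean_scale, so each chunk costs a single mad(); the closing mad(norm_x, w, b) likewise folds the elementwise weight and bias into one fused multiply-add.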
