@@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
     size_t max_alloc_size;
     bool fp16_support;
     bool has_vector_subgroup_broadcast;
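+    // set from the GGML_OPENCL_DISABLE_FUSION environment variable in ggml_cl2_init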
+    bool disable_fusion;
     ggml_cl_compiler_version adreno_cl_compiler_version;
 
     int adreno_wave_size;
@@ -411,7 +412,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
               kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm;
-    cl_kernel kernel_rms_norm;
+    cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
     cl_kernel kernel_group_norm;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
@@ -1100,7 +1101,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_rms_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
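+        // the fused kernel is compiled from the same program source as the plain rms_norm kernel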
+        CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -2110,6 +2112,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
 
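+    // fusion opt-out: setting GGML_OPENCL_DISABLE_FUSION in the environment disables fused ops at graph compute time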
+    backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
+
     dev_ctx->backend_ctx = backend_ctx.release();
     return dev_ctx->backend_ctx;
 }
@@ -2279,7 +2283,45 @@ static void sync_with_other_backends(ggml_backend_t backend) {
     sync_with_other_backends(backend_ctx);
 }
 
+static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
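+    // ggml_can_fuse checks that the nodes match the requested op sequence and that intermediate results are not consumed elsewhere in the graph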
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
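+    // extra constraints for the RMS_NORM + MUL pattern, which maps to the fused kernel_rms_norm_mul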
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            !ggml_are_same_shape(mul->src[0], rms_norm->src[0])) {
+            return false;
+        }
+
+        // rms_norm assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+
 static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2292,6 +2334,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }
 
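+        // dispatch the fused RMS_NORM+MUL kernel and skip the MUL node it replaces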
+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
+            i++;
+            continue;
+        }
+
         bool ok = ggml_cl_compute_forward(backend, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
@@ -4455,6 +4503,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {
+    GGML_ASSERT(mul_tensor);
+    GGML_ASSERT(rms_norm_tensor);
+
+    // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm)
+    const ggml_tensor * src0 = rms_norm_tensor->src[0];
+    const ggml_tensor * src1;
+    if (mul_tensor->src[0] == rms_norm_tensor) {
+        src1 = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == rms_norm_tensor) {
+        src1 = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false && "Invalid args for rms_norm and mul");
+    }
+    const ggml_tensor * dst = mul_tensor;
+
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
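+    // rms_norm stores its epsilon in op_params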
+    float eps;
+    memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));
+
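+    // shapes (ne*) and byte strides (nb*) of the rms_norm input, the mul operand and the fused output, forwarded to the kernel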
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
+
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    GGML_ASSERT(ne00 % 4 == 0);
+
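+    // subgroup size expected by the kernel: 64 (the wave size) on Adreno, 32 on Intel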
+    size_t sgs;
+    if (backend_ctx->gpu_family == ADRENO) {
+        sgs = 64;
+    } else if (backend_ctx->gpu_family == INTEL) {
+        sgs = 32;
+    } else {
+        GGML_ASSERT(false && "Unsupported GPU");
+    }
+
+    cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;
+
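+    // one workgroup per row: grow the workgroup size up to the row length (ne00) and the kernel's maximum workgroup size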
+    int nth = sgs;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth < ne00 && nth < max_workgroup_size) {
+        nth *= 2;
+    }
+    nth = MIN(nth, max_workgroup_size);
+    nth = MIN(nth, ne00);
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),    &eps));
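+    // arg 24 is local scratch memory, presumably one float per subgroup for the partial sums of the row reduction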
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);