@@ -420,9 +420,9 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_clamp;
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
               kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
-    cl_kernel kernel_norm;
+    cl_kernel kernel_norm, kernel_norm_mul_add;
     cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
-    cl_kernel kernel_group_norm;
+    cl_kernel kernel_group_norm, kernel_group_norm_mul_add;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
@@ -1161,7 +1161,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

-        CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_norm         = clCreateKernel(backend_ctx->program_norm, "kernel_norm",         &err), err));
+        CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err));
         GGML_LOG_CONT(".");
     }

@@ -1487,7 +1488,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_group_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

-        CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_group_norm         = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm",         &err), err));
+        CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err));
         GGML_LOG_CONT(".");
     }

@@ -2498,12 +2500,47 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
         if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
             return false;
         }
+    } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
+        const ggml_tensor *norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+        const ggml_tensor *add = cgraph->nodes[node_idx+2];
+        const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0];
+        const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
+
+        // norm fusion only supports F32
+        if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        if (norm->src[0]->ne[0] % 4 != 0) {
+            return false;
+        }
+
+        if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
+            return false;
+        }
+    } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
+        const ggml_tensor *gn = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+        const ggml_tensor *add = cgraph->nodes[node_idx+2];
+        const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0];
+        const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
+
+        if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
+            return false;
+        }
     }

     return true;
 }

 static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
+static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);

 static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -2520,6 +2557,16 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }

+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
+            ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+            i += 2;
+            continue;
+        }
+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
+            ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+            i += 2;
+            continue;
+        }
         if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
             ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
             i++;
@@ -5039,6 +5086,140 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }

+static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
+    GGML_ASSERT(norm_tensor && mul_tensor && add_tensor);
+
+    const ggml_tensor * src0 = norm_tensor->src[0];
+    const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
+    const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
+    const ggml_tensor * dst = add_tensor;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offset2 = extra2->offset + src2->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    float eps;
+    memcpy(&eps, norm_tensor->op_params, sizeof(float));
+
+    const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
+    const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
+    const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3];
+    const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
+    const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3];
+    const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3];
+    const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3];
+
+    size_t sgs;
+    if (backend_ctx->gpu_family == ADRENO) sgs = 64;
+    else if (backend_ctx->gpu_family == INTEL) sgs = 32;
+    else GGML_ASSERT(false && "Unsupported GPU");
+
+    cl_kernel kernel = backend_ctx->kernel_norm_mul_add;
+
+    int nth = sgs;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2;
+    nth = MIN(nth, max_workgroup_size);
+    nth = MIN(nth, ne00/4);
+
+    size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t lws[] = {(size_t)nth, 1, 1};
+    size_t num_subgroups = (nth + sgs - 1) / sgs;
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne20));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne21));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne22));
+    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne23));
+    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21));
+    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22));
+    CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23));
+    CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1));
+    CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2));
+    CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3));
+    CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float),    &eps));
+    CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst);
+}
+
+static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
+    GGML_ASSERT(gn_tensor && mul_tensor && add_tensor);
+
+    const ggml_tensor * src0 = gn_tensor->src[0];
+    const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
+    const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
+    const ggml_tensor * dst = add_tensor;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offset2 = extra2->offset + src2->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    int groups;
+    float eps;
+    memcpy(&groups, gn_tensor->op_params, sizeof(int));
+    memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float));
+
+    cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    int ne = ggml_nelements(src0);
+    int group_size = ne / groups;
+
+    size_t lws[] = { (size_t)MIN(max_workgroup_size, group_size) };
+    size_t gws[] = { (size_t)groups * lws[0] };
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &group_size));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float),    &eps));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst);
+}
+
 static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);