@@ -1007,17 +1007,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
10071007 case GGML_OP_ADD:
10081008 case GGML_OP_SCALE:
10091009 case GGML_OP_MUL:
1010- return true ;
1010+ return op-> src [ 0 ]-> type == GGML_TYPE_F32 ;
10111011 case GGML_OP_UNARY:
10121012 switch (ggml_get_unary_op (op)) {
10131013 case GGML_UNARY_OP_GELU:
10141014 case GGML_UNARY_OP_SILU:
10151015 case GGML_UNARY_OP_RELU:
1016- return ggml_is_contiguous (op->src [0 ]);
1016+ return ggml_is_contiguous (op->src [0 ]) && op-> src [ 0 ]-> type == GGML_TYPE_F32 ;
10171017 default :
10181018 return false ;
10191019 }
10201020 case GGML_OP_CLAMP:
1021+ return op->src [0 ]->type == GGML_TYPE_F32;
10211022 case GGML_OP_SOFT_MAX:
10221023 case GGML_OP_NORM:
10231024 case GGML_OP_RMS_NORM:
@@ -2573,26 +2574,33 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
25732574 memcpy (&eps, dst->op_params , sizeof (float ));
25742575
25752576 const int ne00 = src0 ? src0->ne [0 ] : 0 ;
2576- const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
2577+ const int ne01 = src0 ? src0->ne [1 ] : 0 ;
2578+ const int ne02 = src0 ? src0->ne [2 ] : 0 ;
2579+ const int ne03 = src0 ? src0->ne [3 ] : 0 ;
25772580
2578- GGML_ASSERT (ggml_is_contiguous_1 (src0));
2581+ const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
2582+ const cl_ulong nb02 = src0 ? src0->nb [2 ] : 0 ;
2583+ const cl_ulong nb03 = src0 ? src0->nb [3 ] : 0 ;
25792584
25802585 const int nth = MIN (64 , ne00);
25812586
25822587 cl_kernel kernel = backend_ctx->kernel_norm ;
25832588
2584- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
2585- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
2586- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
2587- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
2588- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &ne00));
2589- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &nb01));
2590- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (float ), &eps));
2591- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (float )*nth, NULL ));
2592-
2593- const int64_t nrows = ggml_nrows (src0);
2589+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
2590+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
2591+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
2592+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
2593+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &ne00));
2594+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (int ), &ne01));
2595+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne02));
2596+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne03));
2597+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb01));
2598+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb02));
2599+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb03));
2600+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (float ), &eps));
2601+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (float )*nth, NULL ));
25942602
2595- size_t global_work_size[] = {(size_t )nrows *nth, 1 , 1 };
2603+ size_t global_work_size[] = {(size_t )ne01 *nth, ( size_t )ne02, ( size_t )ne03 };
25962604 size_t local_work_size[] = {(size_t )nth, 1 , 1 };
25972605
25982606#ifdef GGML_OPENCL_PROFILING
@@ -2630,16 +2638,19 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
26302638 memcpy (&eps, dst->op_params , sizeof (float ));
26312639
26322640 const int ne00 = src0 ? src0->ne [0 ] : 0 ;
2641+ const int ne01 = src0 ? src0->ne [1 ] : 0 ;
2642+ const int ne02 = src0 ? src0->ne [2 ] : 0 ;
2643+ const int ne03 = src0 ? src0->ne [3 ] : 0 ;
2644+
26332645 const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
2646+ const cl_ulong nb02 = src0 ? src0->nb [2 ] : 0 ;
2647+ const cl_ulong nb03 = src0 ? src0->nb [3 ] : 0 ;
26342648
26352649 GGML_ASSERT (ne00 % 4 == 0 );
2636- GGML_ASSERT (ggml_is_contiguous_1 (src0));
26372650
26382651 const int nth = MIN (64 , ne00);
26392652
2640- const int64_t nrows = ggml_nrows (src0);
2641-
2642- size_t global_work_size[] = {(size_t )nrows*nth, 1 , 1 };
2653+ size_t global_work_size[] = {(size_t )ne01*nth, (size_t )ne02, (size_t )ne03};
26432654 size_t local_work_size[] = {(size_t )nth, 1 , 1 };
26442655
26452656 cl_kernel kernel = backend_ctx->kernel_rms_norm ;
@@ -2654,15 +2665,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
26542665 sizeof (local_work_size), local_work_size,
26552666 sizeof (size_t ), &sgs, NULL ));
26562667
2657- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
2658- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
2659- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
2660- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
2661- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &ne00));
2662- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &nb01));
2663- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (float ), &eps));
2668+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
2669+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
2670+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
2671+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
2672+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &ne00));
2673+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (int ), &ne01));
2674+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne02));
2675+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne03));
2676+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb01));
2677+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb02));
2678+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb03));
2679+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (float ), &eps));
26642680 // This is local memory - the size depends on subgroup size.
2665- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (float )*nth/sgs, NULL ));
2681+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (float )*nth/sgs, NULL ));
26662682
26672683#ifdef GGML_OPENCL_PROFILING
26682684 cl_event evt;
0 commit comments