@@ -2637,16 +2637,19 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
2637
2637
memcpy (&eps, dst->op_params , sizeof (float ));
2638
2638
2639
2639
const int ne00 = src0 ? src0->ne [0 ] : 0 ;
2640
+ const int ne01 = src0 ? src0->ne [1 ] : 0 ;
2641
+ const int ne02 = src0 ? src0->ne [2 ] : 0 ;
2642
+ const int ne03 = src0 ? src0->ne [3 ] : 0 ;
2643
+
2640
2644
const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
2645
+ const cl_ulong nb02 = src0 ? src0->nb [2 ] : 0 ;
2646
+ const cl_ulong nb03 = src0 ? src0->nb [3 ] : 0 ;
2641
2647
2642
2648
GGML_ASSERT (ne00 % 4 == 0 );
2643
- GGML_ASSERT (ggml_is_contiguous_1 (src0));
2644
2649
2645
2650
const int nth = MIN (64 , ne00);
2646
2651
2647
- const int64_t nrows = ggml_nrows (src0);
2648
-
2649
- size_t global_work_size[] = {(size_t )nrows*nth, 1 , 1 };
2652
+ size_t global_work_size[] = {(size_t )ne01*nth, (size_t )ne02, (size_t )ne03};
2650
2653
size_t local_work_size[] = {(size_t )nth, 1 , 1 };
2651
2654
2652
2655
cl_kernel kernel = backend_ctx->kernel_rms_norm ;
@@ -2661,15 +2664,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
2661
2664
sizeof (local_work_size), local_work_size,
2662
2665
sizeof (size_t ), &sgs, NULL ));
2663
2666
2664
- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
2665
- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
2666
- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
2667
- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
2668
- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &ne00));
2669
- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &nb01));
2670
- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (float ), &eps));
2667
+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
2668
+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
2669
+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
2670
+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
2671
+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &ne00));
2672
+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (int ), &ne01));
2673
+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne02));
2674
+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne03));
2675
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb01));
2676
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb02));
2677
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb03));
2678
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (float ), &eps));
2671
2679
// This is local memory - the size depends on subgroup size.
2672
- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (float )*nth/sgs, NULL ));
2680
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (float )*nth/sgs, NULL ));
2673
2681
2674
2682
#ifdef GGML_OPENCL_PROFILING
2675
2683
cl_event evt;
0 commit comments