@@ -2573,26 +2573,33 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
2573
2573
memcpy (&eps, dst->op_params , sizeof (float ));
2574
2574
2575
2575
const int ne00 = src0 ? src0->ne [0 ] : 0 ;
2576
- const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
2576
+ const int ne01 = src0 ? src0->ne [1 ] : 0 ;
2577
+ const int ne02 = src0 ? src0->ne [2 ] : 0 ;
2578
+ const int ne03 = src0 ? src0->ne [3 ] : 0 ;
2577
2579
2578
- GGML_ASSERT (ggml_is_contiguous_1 (src0));
2580
+ const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
2581
+ const cl_ulong nb02 = src0 ? src0->nb [2 ] : 0 ;
2582
+ const cl_ulong nb03 = src0 ? src0->nb [3 ] : 0 ;
2579
2583
2580
2584
const int nth = MIN (64 , ne00);
2581
2585
2582
2586
cl_kernel kernel = backend_ctx->kernel_norm ;
2583
2587
2584
- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
2585
- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
2586
- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
2587
- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
2588
- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &ne00));
2589
- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &nb01));
2590
- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (float ), &eps));
2591
- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (float )*nth, NULL ));
2588
+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
2589
+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
2590
+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extrad->data_device ));
2591
+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offsetd));
2592
+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &ne00));
2593
+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (int ), &ne01));
2594
+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne02));
2595
+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne03));
2596
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb01));
2597
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb02));
2598
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb03));
2599
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (float ), &eps));
2600
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (float )*nth, NULL ));
2592
2601
2593
- const int64_t nrows = ggml_nrows (src0);
2594
-
2595
- size_t global_work_size[] = {(size_t )nrows*nth, 1 , 1 };
2602
+ size_t global_work_size[] = {(size_t )ne01*nth, (size_t )ne02, (size_t )ne03};
2596
2603
size_t local_work_size[] = {(size_t )nth, 1 , 1 };
2597
2604
2598
2605
#ifdef GGML_OPENCL_PROFILING
0 commit comments