@@ -3780,15 +3780,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
37803780 GGML_ASSERT (dst);
37813781 GGML_ASSERT (dst->extra );
37823782
3783- const int ne00 = src0 ? src0->ne [0 ] : 0 ;
3784- const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
3785- const cl_ulong nb02 = src0 ? src0->nb [2 ] : 0 ;
3786- const int ne10 = src1 ? src1->ne [0 ] : 0 ;
3787- const cl_ulong nb10 = src1 ? src1->nb [0 ] : 0 ;
3788- const int ne11 = src1 ? src1->ne [1 ] : 0 ;
3789- const cl_ulong nb11 = src1 ? src1->nb [1 ] : 0 ;
3790- const cl_ulong nb1 = dst ? dst->nb [1 ] : 0 ;
3791- const cl_ulong nb2 = dst ? dst->nb [2 ] : 0 ;
3783+ const int ne00 = src0->ne [0 ];
3784+ const cl_ulong nb01 = src0->nb [1 ];
3785+ const cl_ulong nb02 = src0->nb [2 ];
3786+ const cl_ulong nb03 = src0->nb [3 ];
3787+ const int ne10 = src1->ne [0 ];
3788+ const cl_ulong nb10 = src1->nb [0 ];
3789+ const int ne11 = src1->ne [1 ];
3790+ const int ne12 = src1->ne [2 ];
3791+ const cl_ulong nb11 = src1->nb [1 ];
3792+ const cl_ulong nb12 = src1->nb [2 ];
3793+ const cl_ulong nb1 = dst->nb [1 ];
3794+ const cl_ulong nb2 = dst->nb [2 ];
3795+ const cl_ulong nb3 = dst->nb [3 ];
37923796
37933797 ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
37943798
@@ -3825,14 +3829,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
38253829 CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
38263830 CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &nb01));
38273831 CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb02));
3828- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne10));
3829- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb10));
3830- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb11));
3831- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb1));
3832- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb2));
3833-
3834- size_t global_work_size[] = {(size_t )ne10, (size_t )ne11, 1 };
3835- size_t local_work_size[] = {1 , 1 , 1 };
3832+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb03));
3833+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne10));
3834+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb10));
3835+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb11));
3836+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb12));
3837+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (cl_ulong), &nb1));
3838+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb2));
3839+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb3));
3840+
3841+ size_t global_work_size[] = {(size_t )ne10*64 , (size_t )ne11, (size_t )ne12};
3842+ size_t local_work_size[] = {64 , 1 , 1 };
38363843
38373844 backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
38383845}
0 commit comments