@@ -3780,15 +3780,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
3780
3780
GGML_ASSERT (dst);
3781
3781
GGML_ASSERT (dst->extra );
3782
3782
3783
- const int ne00 = src0 ? src0->ne [0 ] : 0 ;
3784
- const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
3785
- const cl_ulong nb02 = src0 ? src0->nb [2 ] : 0 ;
3786
- const int ne10 = src1 ? src1->ne [0 ] : 0 ;
3787
- const cl_ulong nb10 = src1 ? src1->nb [0 ] : 0 ;
3788
- const int ne11 = src1 ? src1->ne [1 ] : 0 ;
3789
- const cl_ulong nb11 = src1 ? src1->nb [1 ] : 0 ;
3790
- const cl_ulong nb1 = dst ? dst->nb [1 ] : 0 ;
3791
- const cl_ulong nb2 = dst ? dst->nb [2 ] : 0 ;
3783
+ const int ne00 = src0->ne [0 ];
3784
+ const cl_ulong nb01 = src0->nb [1 ];
3785
+ const cl_ulong nb02 = src0->nb [2 ];
3786
+ const cl_ulong nb03 = src0->nb [3 ];
3787
+ const int ne10 = src1->ne [0 ];
3788
+ const cl_ulong nb10 = src1->nb [0 ];
3789
+ const int ne11 = src1->ne [1 ];
3790
+ const int ne12 = src1->ne [2 ];
3791
+ const cl_ulong nb11 = src1->nb [1 ];
3792
+ const cl_ulong nb12 = src1->nb [2 ];
3793
+ const cl_ulong nb1 = dst->nb [1 ];
3794
+ const cl_ulong nb2 = dst->nb [2 ];
3795
+ const cl_ulong nb3 = dst->nb [3 ];
3792
3796
3793
3797
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
3794
3798
@@ -3825,14 +3829,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
3825
3829
CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
3826
3830
CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &nb01));
3827
3831
CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb02));
3828
- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne10));
3829
- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb10));
3830
- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb11));
3831
- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb1));
3832
- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb2));
3833
-
3834
- size_t global_work_size[] = {(size_t )ne10, (size_t )ne11, 1 };
3835
- size_t local_work_size[] = {1 , 1 , 1 };
3832
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb03));
3833
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne10));
3834
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb10));
3835
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb11));
3836
+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb12));
3837
+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (cl_ulong), &nb1));
3838
+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb2));
3839
+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb3));
3840
+
3841
+ size_t global_work_size[] = {(size_t )ne10*64 , (size_t )ne11, (size_t )ne12};
3842
+ size_t local_work_size[] = {64 , 1 , 1 };
3836
3843
3837
3844
backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
3838
3845
}
0 commit comments