@@ -4222,15 +4222,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
4222
4222
GGML_ASSERT (dst);
4223
4223
GGML_ASSERT (dst->extra );
4224
4224
4225
- const int ne00 = src0 ? src0->ne [0 ] : 0 ;
4226
- const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
4227
- const cl_ulong nb02 = src0 ? src0->nb [2 ] : 0 ;
4228
- const int ne10 = src1 ? src1->ne [0 ] : 0 ;
4229
- const cl_ulong nb10 = src1 ? src1->nb [0 ] : 0 ;
4230
- const int ne11 = src1 ? src1->ne [1 ] : 0 ;
4231
- const cl_ulong nb11 = src1 ? src1->nb [1 ] : 0 ;
4232
- const cl_ulong nb1 = dst ? dst->nb [1 ] : 0 ;
4233
- const cl_ulong nb2 = dst ? dst->nb [2 ] : 0 ;
4225
+ const int ne00 = src0->ne [0 ];
4226
+ const cl_ulong nb01 = src0->nb [1 ];
4227
+ const cl_ulong nb02 = src0->nb [2 ];
4228
+ const cl_ulong nb03 = src0->nb [3 ];
4229
+ const int ne10 = src1->ne [0 ];
4230
+ const cl_ulong nb10 = src1->nb [0 ];
4231
+ const int ne11 = src1->ne [1 ];
4232
+ const int ne12 = src1->ne [2 ];
4233
+ const cl_ulong nb11 = src1->nb [1 ];
4234
+ const cl_ulong nb12 = src1->nb [2 ];
4235
+ const cl_ulong nb1 = dst->nb [1 ];
4236
+ const cl_ulong nb2 = dst->nb [2 ];
4237
+ const cl_ulong nb3 = dst->nb [3 ];
4234
4238
4235
4239
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
4236
4240
@@ -4267,14 +4271,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
4267
4271
CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
4268
4272
CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &nb01));
4269
4273
CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb02));
4270
- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne10));
4271
- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb10));
4272
- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb11));
4273
- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb1));
4274
- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb2));
4275
-
4276
- size_t global_work_size[] = {(size_t )ne10, (size_t )ne11, 1 };
4277
- size_t local_work_size[] = {1 , 1 , 1 };
4274
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb03));
4275
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne10));
4276
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb10));
4277
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb11));
4278
+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb12));
4279
+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (cl_ulong), &nb1));
4280
+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb2));
4281
+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb3));
4282
+
4283
+ size_t global_work_size[] = {(size_t )ne10*64 , (size_t )ne11, (size_t )ne12};
4284
+ size_t local_work_size[] = {64 , 1 , 1 };
4278
4285
4279
4286
backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
4280
4287
}
0 commit comments