@@ -309,6 +309,7 @@ struct ggml_backend_opencl_context {
     cl_program program_softmax_f16;
     cl_program program_softmax_4_f32;
     cl_program program_softmax_4_f16;
+    cl_program program_argsort_f32_i32;
 
     cl_kernel kernel_add, kernel_add_row;
     cl_kernel kernel_mul, kernel_mul_row;
@@ -339,6 +340,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
     cl_kernel kernel_mul_mv_q6_K_f32;
     cl_kernel kernel_im2col_f32, kernel_im2col_f16;
+    cl_kernel kernel_argsort_f32_i32;
 
 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     // Transpose kernels
@@ -986,6 +988,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // argsort
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "argsort.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("argsort.cl");
+#endif
+        backend_ctx->program_argsort_f32_i32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // Adreno kernels
 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     // transpose
@@ -1912,6 +1930,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         }
         case GGML_OP_IM2COL:
             return true;
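+        // argsort handles f32 input only; the output indices are i32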
+        case GGML_OP_ARGSORT:
+            return op->src[0]->type == GGML_TYPE_F32;
         default:
             return false;
     }
@@ -4975,6 +4995,61 @@ static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, con
 #endif
 }
 
+static void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+    GGML_UNUSED(src1);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const int ne00 = src0->ne[0];
+    const int nrows = ggml_nrows(src0);
+
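+    // pad the row length up to the next power of two: the sort kernel works on power-of-two sized rows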
+    int ne00_padded = 1;
+    while (ne00_padded < ne00) {
+        ne00_padded *= 2;
+    }
+
+    int order = (enum ggml_sort_order) dst->op_params[0];
+
+    cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne00_padded));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &order));
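+    // arg 7: local (work-group) scratch memory, one int per padded element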
+    CL_CHECK(clSetKernelArg(kernel, 7, ne00_padded*sizeof(int), NULL));
+
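+    // one work-group per row; each work-item handles one padded element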
+    size_t global_work_size[] = {(size_t)ne00_padded, (size_t)nrows, (size_t)1};
+    size_t local_work_size[] = {(size_t)ne00_padded, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
 // ------------------------------------------------------------------------------
 // Op offloading
 // ------------------------------------------------------------------------------
@@ -5115,6 +5190,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_im2col;
             break;
+        case GGML_OP_ARGSORT:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_argsort;
+            break;
         default:
             return false;
     }
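
The diff wires up the host side but does not include `argsort.cl` itself. For reference, below is a minimal sketch of an OpenCL kernel matching the argument list set in `ggml_cl_argsort` above: a work-group-wide bitonic argsort modeled on the CUDA backend's `k_argsort_f32_i32`, with the local `int` buffer (arg 7) holding the index permutation. This is an illustration of the expected shape of the kernel, not the actual `argsort.cl` from this commit; names and details may differ.

// sketch of argsort.cl (hypothetical, modeled on the CUDA backend)
// must match ggml_sort_order in ggml.h
enum ggml_sort_order {
    GGML_SORT_ORDER_ASC,
    GGML_SORT_ORDER_DESC,
};

kernel void kernel_argsort_f32_i32(
        global float * src0,        // arg 0: input rows (f32)
        ulong          offset0,     // arg 1: byte offset into src0
        global int   * dst,         // arg 2: output index rows (i32)
        ulong          offsetd,     // arg 3: byte offset into dst
        const int      ne00,        // arg 4: actual row length
        const int      ne00_padded, // arg 5: row length padded to a power of two
        const int      order,       // arg 6: ggml_sort_order
        local  int   * indices      // arg 7: ne00_padded ints of local scratch
) {
    src0 = (global float *)((global char *)src0 + offset0);
    dst  = (global int   *)((global char *)dst  + offsetd);

    const int col = get_local_id(0); // one work-item per padded column
    const int row = get_group_id(1); // one work-group per row

    const global float * x = src0 + row * ne00;

    // each work-item starts out owning its own column index
    indices[col] = col;
    barrier(CLK_LOCAL_MEM_FENCE);

    // bitonic sort of the index permutation; padding slots (index >= ne00)
    // always compare as "greater", so they collect at the end of the row
    for (int k = 2; k <= ne00_padded; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            const int ixj = col ^ j;
            if (ixj > col) {
                const int a = indices[col];
                const int b = indices[ixj];
                bool do_swap;
                if ((col & k) == 0) {
                    do_swap = a >= ne00 ||
                              (b < ne00 && (order == GGML_SORT_ORDER_ASC ? x[a] > x[b]
                                                                         : x[a] < x[b]));
                } else {
                    do_swap = b >= ne00 ||
                              (a < ne00 && (order == GGML_SORT_ORDER_ASC ? x[a] < x[b]
                                                                         : x[a] > x[b]));
                }
                if (do_swap) {
                    indices[col] = b;
                    indices[ixj] = a;
                }
            }
            barrier(CLK_LOCAL_MEM_FENCE);
        }
    }

    // write the result back without the padding
    if (col < ne00) {
        dst[row * ne00 + col] = indices[col];
    }
}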