@@ -2347,58 +2347,6 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
23472347 }
23482348}
23492349
2350- static void argmax_f32_i32_sycl (const float *x, int *dst, const int ncols,
2351- const int nrows, queue_ptr stream) {
2352- const sycl::range<3 > block_dims (1 , 1 , SYCL_ARGMAX_BLOCK_SIZE);
2353- const sycl::range<3 > block_nums (1 , nrows, 1 );
2354- const size_t shared_mem = 256 * sizeof (float );
2355-
2356- stream->submit ([&](sycl::handler &cgh) {
2357- sycl::local_accessor<float , 1 > shared_data (
2358- sycl::range<1 >(shared_mem/sizeof (float )), cgh);
2359- sycl::local_accessor<int , 1 > shared_indices (
2360- sycl::range<1 >(shared_mem/sizeof (float )), cgh);
2361-
2362- cgh.parallel_for (
2363- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
2364- [=](sycl::nd_item<3 > item_ct1) {
2365- const int tid = item_ct1.get_local_id (2 );
2366- const int row = item_ct1.get_global_id (1 );
2367-
2368- float max_val = -INFINITY;
2369- int max_idx = -1 ;
2370-
2371- for (int col = tid; col < ncols; col += 256 ) {
2372- float val = x[row * ncols + col];
2373- if (val > max_val) {
2374- max_val = val;
2375- max_idx = col;
2376- }
2377- }
2378-
2379- shared_data[tid] = max_val;
2380- shared_indices[tid] = max_idx;
2381- item_ct1.barrier (sycl::access::fence_space::local_space);
2382-
2383- for (int stride = 256 /2 ; stride > 0 ; stride >>= 1 ) {
2384- if (tid < stride) {
2385- float val1 = shared_data[tid];
2386- float val2 = shared_data[tid + stride];
2387- if (val2 > val1) {
2388- shared_data[tid] = val2;
2389- shared_indices[tid] = shared_indices[tid + stride];
2390- }
2391- }
2392- item_ct1.barrier (sycl::access::fence_space::local_space);
2393- }
2394-
2395-
2396- if (tid == 0 ) {
2397- dst[row] = shared_indices[0 ];
2398- }
2399- });
2400- });
2401- }
24022350static void diag_mask_inf_f32_sycl (const float *x, float *dst,
24032351 const int ncols_x, const int nrows_x,
24042352 const int rows_per_channel, const int n_past,
@@ -2746,22 +2694,6 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
27462694 argsort_f32_i32_sycl (src0_dd, dst_dd, ncols, nrows, order, main_stream);
27472695}
27482696
2749- inline void ggml_sycl_op_argmax (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
2750- GGML_ASSERT (ggml_is_contiguous (dst->src [0 ]));
2751-
2752- GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
2753- GGML_ASSERT (dst->type == GGML_TYPE_I32);
2754-
2755- const int64_t ncols = dst->src [0 ]->ne [0 ];
2756- const int64_t nrows = ggml_nrows (dst->src [0 ]);
2757-
2758- dpct::queue_ptr main_stream = ctx.stream ();
2759- const float * src0_dd = static_cast <const float *>(dst->src [0 ]->data );
2760- int32_t * dst_dd = static_cast <int32_t *>(dst->data );
2761-
2762- argmax_f32_i32_sycl (src0_dd, dst_dd, ncols, nrows, main_stream);
2763- }
2764-
27652697inline void ggml_sycl_op_diag_mask_inf (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
27662698
27672699 GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
0 commit comments