Skip to content

Commit e19266b

Browse files
author
alexju
committed
metal : refactor argsort parameters into a struct
1 parent 53d36ee commit e19266b

File tree

3 files changed

+24
-18
lines changed

3 files changed

+24
-18
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,11 @@ typedef struct {
491491
int max_period;
492492
} ggml_metal_kargs_timestep_embedding;
493493

494+
typedef struct { // arguments for the argsort kernel; the host fills this and passes it with setBytes at buffer index 2
495+
int64_t ncols; // set from ne00 on the host; the kernel uses it as the per-row stride and treats indices >= ncols as padding
496+
int64_t ncols_pad; // set from ne00_padded on the host; upper bound of the kernel's bitonic sort loop and of the thread column index
497+
} ggml_metal_kargs_argsort;
498+
494499
typedef struct {
495500
int64_t ne0;
496501
float start;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3595,12 +3595,15 @@ static void ggml_metal_encode_node(
35953595
default: GGML_ABORT("fatal error");
35963596
};
35973597

3598-
// TODO: add ggml_metal_kargs struct
3598+
ggml_metal_kargs_argsort args = {
3599+
/*.ncols =*/ ne00,
3600+
/*.ncols_pad =*/ ne00_padded
3601+
};
3602+
35993603
[encoder setComputePipelineState:pipeline];
3600-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3601-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3602-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
3603-
[encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
3604+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3605+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3606+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
36043607
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
36053608

36063609
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2796,8 +2796,7 @@ kernel void kernel_timestep_embedding_f32(
27962796
typedef void (argsort_t)(
27972797
device const float * x,
27982798
device int32_t * dst,
2799-
constant int64_t & ncols,
2800-
constant int64_t & ncols_pad,
2799+
constant ggml_metal_kargs_argsort & args,
28012800
threadgroup int32_t * shared_values [[threadgroup(0)]],
28022801
uint3 tgpig[[threadgroup_position_in_grid]],
28032802
uint3 tpitg[[thread_position_in_threadgroup]]);
@@ -2806,40 +2805,39 @@ template<ggml_sort_order order>
28062805
kernel void kernel_argsort_f32_i32(
28072806
device const float * x,
28082807
device int32_t * dst,
2809-
constant int64_t & ncols,
2810-
constant int64_t & ncols_pad,
2808+
constant ggml_metal_kargs_argsort & args,
28112809
threadgroup int32_t * shared_values [[threadgroup(0)]],
28122810
uint3 tgpig[[threadgroup_position_in_grid]],
28132811
uint3 tpitg[[thread_position_in_threadgroup]]) {
28142812
// bitonic sort
28152813
int col = tpitg[0];
28162814
int row = tgpig[1];
28172815

2818-
if (col >= ncols_pad) return;
2816+
if (col >= args.ncols_pad) return;
28192817

2820-
device const float * x_row = x + row * ncols;
2818+
device const float * x_row = x + row * args.ncols;
28212819
threadgroup int32_t * dst_row = shared_values;
28222820

28232821
// initialize indices
28242822
dst_row[col] = col;
28252823

28262824
threadgroup_barrier(mem_flags::mem_threadgroup);
28272825

2828-
for (int k = 2; k <= ncols_pad; k *= 2) {
2826+
for (int k = 2; k <= args.ncols_pad; k *= 2) {
28292827
for (int j = k / 2; j > 0; j /= 2) {
28302828
int ixj = col ^ j;
28312829
if (ixj > col) {
28322830
if ((col & k) == 0) {
2833-
if (dst_row[col] >= ncols ||
2834-
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
2831+
if (dst_row[col] >= args.ncols ||
2832+
(dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
28352833
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
28362834
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
28372835
) {
28382836
SWAP(dst_row[col], dst_row[ixj]);
28392837
}
28402838
} else {
2841-
if (dst_row[ixj] >= ncols ||
2842-
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
2839+
if (dst_row[ixj] >= args.ncols ||
2840+
(dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
28432841
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
28442842
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
28452843
) {
@@ -2852,8 +2850,8 @@ kernel void kernel_argsort_f32_i32(
28522850
}
28532851

28542852
// copy the result to dst without the padding
2855-
if (col < ncols) {
2856-
dst[row * ncols + col] = dst_row[col];
2853+
if (col < args.ncols) {
2854+
dst[row * args.ncols + col] = dst_row[col];
28572855
}
28582856
}
28592857

0 commit comments

Comments
 (0)