@@ -2796,8 +2796,7 @@ kernel void kernel_timestep_embedding_f32(
27962796typedef void (argsort_t )(
27972797 device const float * x,
27982798 device int32_t * dst,
2799- constant int64_t & ncols,
2800- constant int64_t & ncols_pad,
2799+ constant ggml_metal_kargs_argsort & args,
28012800 threadgroup int32_t * shared_values [[threadgroup(0 )]],
28022801 uint3 tgpig[[threadgroup_position_in_grid]],
28032802 uint3 tpitg[[thread_position_in_threadgroup]]);
@@ -2806,40 +2805,39 @@ template<ggml_sort_order order>
28062805kernel void kernel_argsort_f32_i32 (
28072806 device const float * x,
28082807 device int32_t * dst,
2809- constant int64_t & ncols,
2810- constant int64_t & ncols_pad,
2808+ constant ggml_metal_kargs_argsort & args,
28112809 threadgroup int32_t * shared_values [[threadgroup(0 )]],
28122810 uint3 tgpig[[threadgroup_position_in_grid]],
28132811 uint3 tpitg[[thread_position_in_threadgroup]]) {
28142812 // bitonic sort
28152813 int col = tpitg[0 ];
28162814 int row = tgpig[1 ];
28172815
2818- if (col >= ncols_pad) return ;
2816+ if (col >= args. ncols_pad ) return ;
28192817
2820- device const float * x_row = x + row * ncols;
2818+ device const float * x_row = x + row * args. ncols ;
28212819 threadgroup int32_t * dst_row = shared_values;
28222820
28232821 // initialize indices
28242822 dst_row[col] = col;
28252823
28262824 threadgroup_barrier (mem_flags::mem_threadgroup);
28272825
2828- for (int k = 2 ; k <= ncols_pad; k *= 2 ) {
2826+ for (int k = 2 ; k <= args. ncols_pad ; k *= 2 ) {
28292827 for (int j = k / 2 ; j > 0 ; j /= 2 ) {
28302828 int ixj = col ^ j;
28312829 if (ixj > col) {
28322830 if ((col & k) == 0 ) {
2833- if (dst_row[col] >= ncols ||
2834- (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
2831+ if (dst_row[col] >= args. ncols ||
2832+ (dst_row[ixj] < args. ncols && (order == GGML_SORT_ORDER_ASC ?
28352833 x_row[dst_row[col]] > x_row[dst_row[ixj]] :
28362834 x_row[dst_row[col]] < x_row[dst_row[ixj]]))
28372835 ) {
28382836 SWAP (dst_row[col], dst_row[ixj]);
28392837 }
28402838 } else {
2841- if (dst_row[ixj] >= ncols ||
2842- (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
2839+ if (dst_row[ixj] >= args. ncols ||
2840+ (dst_row[col] < args. ncols && (order == GGML_SORT_ORDER_ASC ?
28432841 x_row[dst_row[col]] < x_row[dst_row[ixj]] :
28442842 x_row[dst_row[col]] > x_row[dst_row[ixj]]))
28452843 ) {
@@ -2852,8 +2850,8 @@ kernel void kernel_argsort_f32_i32(
28522850 }
28532851
28542852 // copy the result to dst without the padding
2855- if (col < ncols) {
2856- dst[row * ncols + col] = dst_row[col];
2853+ if (col < args. ncols ) {
2854+ dst[row * args. ncols + col] = dst_row[col];
28572855 }
28582856}
28592857
0 commit comments