@@ -5885,28 +5885,21 @@ kernel void kernel_get_rows_q(
58855885 device const void * src0,
58865886 device const void * src1,
58875887 device float * dst,
5888- constant int64_t & ne00,
5889- constant uint64_t & nb01,
5890- constant uint64_t & nb02,
5891- constant int64_t & ne10,
5892- constant uint64_t & nb10,
5893- constant uint64_t & nb11,
5894- constant uint64_t & nb1,
5895- constant uint64_t & nb2,
5888+ constant ggml_metal_kargs_get_rows & args,
58965889 uint3 tgpig[[threadgroup_position_in_grid]],
58975890 uint tiitg[[thread_index_in_threadgroup]],
58985891 uint3 tptg [[threads_per_threadgroup]]) {
58995892 const int64_t i10 = tgpig.x ;
59005893 const int64_t i11 = tgpig.y ;
59015894
5902- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0 ];
5895+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args. nb11 + i10*args. nb10 ))[0 ];
59035896
59045897 const int64_t i02 = i11;
59055898
5906- for (int64_t ind = tiitg; ind < ne00/16 ; ind += tptg.x ) {
5899+ for (int64_t ind = tiitg; ind < args. ne00 /16 ; ind += tptg.x ) {
59075900 float4x4 temp;
5908- dequantize_func (((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
5909- *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
5901+ dequantize_func (((device const block_q *) ((const device char *) src0 + r*args. nb01 + i02*args. nb02 )) + ind/nl, ind%nl, temp);
5902+ *(((device float4x4 *) ((device char *) dst + i11*args. nb2 + i10*args. nb1 )) + ind) = temp;
59105903 }
59115904}
59125905
@@ -5915,55 +5908,41 @@ kernel void kernel_get_rows_f(
59155908 device const void * src0,
59165909 device const void * src1,
59175910 device float * dst,
5918- constant int64_t & ne00,
5919- constant uint64_t & nb01,
5920- constant uint64_t & nb02,
5921- constant int64_t & ne10,
5922- constant uint64_t & nb10,
5923- constant uint64_t & nb11,
5924- constant uint64_t & nb1,
5925- constant uint64_t & nb2,
5911+ constant ggml_metal_kargs_get_rows & args,
59265912 uint3 tgpig[[threadgroup_position_in_grid]],
59275913 uint tiitg[[thread_index_in_threadgroup]],
59285914 uint3 tptg [[threads_per_threadgroup]]) {
59295915 const int64_t i10 = tgpig.x ;
59305916 const int64_t i11 = tgpig.y ;
59315917
5932- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0 ];
5918+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args. nb11 + i10*args. nb10 ))[0 ];
59335919
59345920 const int64_t i02 = i11;
59355921
5936- for (int ind = tiitg; ind < ne00; ind += tptg.x ) {
5937- (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
5938- ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
5922+ for (int ind = tiitg; ind < args. ne00 ; ind += tptg.x ) {
5923+ (( device float *) (( device char *) dst + i11*args. nb2 + i10*args. nb1 ))[ind] =
5924+ ((const device T *) ((const device char *) src0 + i02*args. nb02 + r*args. nb01 ))[ind];
59395925 }
59405926}
59415927
59425928kernel void kernel_get_rows_i32 (
59435929 device const void * src0,
59445930 device const void * src1,
59455931 device int32_t * dst,
5946- constant int64_t & ne00,
5947- constant uint64_t & nb01,
5948- constant uint64_t & nb02,
5949- constant int64_t & ne10,
5950- constant uint64_t & nb10,
5951- constant uint64_t & nb11,
5952- constant uint64_t & nb1,
5953- constant uint64_t & nb2,
5932+ constant ggml_metal_kargs_get_rows & args,
59545933 uint3 tgpig[[threadgroup_position_in_grid]],
59555934 uint tiitg[[thread_index_in_threadgroup]],
59565935 uint3 tptg [[threads_per_threadgroup]]) {
59575936 const int64_t i10 = tgpig.x ;
59585937 const int64_t i11 = tgpig.y ;
59595938
5960- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0 ];
5939+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args. nb11 + i10*args. nb10 ))[0 ];
59615940
59625941 const int64_t i02 = i11;
59635942
5964- for (int ind = tiitg; ind < ne00; ind += tptg.x ) {
5965- (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
5966- ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
5943+ for (int ind = tiitg; ind < args. ne00 ; ind += tptg.x ) {
5944+ (( device int32_t *) (( device char *) dst + i11*args. nb2 + i10*args. nb1 ))[ind] =
5945+ ((const device int32_t *) ((const device char *) src0 + i02*args. nb02 + r*args. nb01 ))[ind];
59675946 }
59685947}
59695948
0 commit comments