Skip to content

Commit 9ed11a6

Browse files
committed
metal : add ggml_set_rows implementation
ggml-ci
1 parent 681144e commit 9ed11a6

File tree

3 files changed

+131
-8
lines changed

3 files changed

+131
-8
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,22 @@ typedef struct {
521521
uint64_t nb2;
522522
} ggml_metal_kargs_get_rows;
523523

524+
typedef struct {
525+
int32_t nk0;
526+
int32_t ne01;
527+
uint64_t nb01;
528+
uint64_t nb02;
529+
uint64_t nb03;
530+
int32_t ne11;
531+
int32_t ne12;
532+
uint64_t nb10;
533+
uint64_t nb11;
534+
uint64_t nb12;
535+
uint64_t nb1;
536+
uint64_t nb2;
537+
uint64_t nb3;
538+
} ggml_metal_kargs_set_rows;
539+
524540
typedef struct {
525541
int64_t ne00;
526542
int64_t ne01;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
202202
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
203203
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
204204
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207+
GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
205208
GGML_METAL_KERNEL_TYPE_RMS_NORM,
206209
GGML_METAL_KERNEL_TYPE_L2_NORM,
207210
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1166,6 +1169,9 @@ @implementation GGMLMetalClass
11661169
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
11671170
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
11681171
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
1172+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
1173+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
1174+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
11691175
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
11701176
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
11711177
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1630,7 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16301636

16311637
if (!use_bfloat) {
16321638
for (size_t i = 0, n = 3; i < n; ++i) {
1633-
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
1639+
if (op->src[i] != NULL && (op->src[i]->type == GGML_TYPE_BF16 || op->type == GGML_TYPE_BF16)) {
16341640
return false;
16351641
}
16361642
}
@@ -1798,6 +1804,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
17981804
{
17991805
return op->ne[3] == 1;
18001806
}
1807+
case GGML_OP_SET_ROWS:
1808+
{
1809+
return op->src[0]->type == GGML_TYPE_F32 && ggml_blck_size(op->type) == 1; // tmp
1810+
}
18011811
default:
18021812
return false;
18031813
}
@@ -3757,13 +3767,68 @@ static bool ggml_metal_encode_node(
37573767
};
37583768

37593769
[encoder setComputePipelineState:pipeline];
3760-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3761-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3762-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3763-
[encoder setBytes:&args length:sizeof(args) atIndex:3];
3770+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3771+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3772+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3773+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
37643774

37653775
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
37663776
} break;
3777+
case GGML_OP_SET_ROWS:
3778+
{
3779+
id<MTLComputePipelineState> pipeline = nil;
3780+
3781+
switch (dst->type) {
3782+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
3783+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
3784+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
3785+
default: GGML_ABORT("not implemented");
3786+
}
3787+
3788+
const int32_t nk0 = ne0/ggml_blck_size(dst->type);
3789+
3790+
int nth = 32; // SIMD width
3791+
3792+
while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
3793+
nth *= 2;
3794+
}
3795+
3796+
int nrptg = 1;
3797+
if (nth > nk0) {
3798+
nrptg = (nth + nk0 - 1)/nk0;
3799+
nth = nk0;
3800+
3801+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
3802+
nrptg--;
3803+
}
3804+
}
3805+
3806+
nth = MIN(nth, nk0);
3807+
3808+
ggml_metal_kargs_set_rows args = {
3809+
/*.nk0 =*/ nk0,
3810+
/*.ne01 =*/ ne01,
3811+
/*.nb01 =*/ nb01,
3812+
/*.nb02 =*/ nb02,
3813+
/*.nb03 =*/ nb03,
3814+
/*.ne11 =*/ ne11,
3815+
/*.ne12 =*/ ne12,
3816+
/*.nb10 =*/ nb10,
3817+
/*.nb11 =*/ nb11,
3818+
/*.nb12 =*/ nb12,
3819+
/*.nb1 =*/ nb1,
3820+
/*.nb2 =*/ nb2,
3821+
/*.nb3 =*/ nb3,
3822+
};
3823+
3824+
[encoder setComputePipelineState:pipeline];
3825+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3826+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3827+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3828+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3829+
3830+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
3831+
} break;
37673832
case GGML_OP_RMS_NORM:
37683833
{
37693834
GGML_ASSERT(ne00 % 4 == 0);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6350,10 +6350,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
63506350

63516351
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
63526352
kernel void kernel_get_rows_q(
6353+
constant ggml_metal_kargs_get_rows & args,
63536354
device const void * src0,
63546355
device const void * src1,
63556356
device float * dst,
6356-
constant ggml_metal_kargs_get_rows & args,
63576357
uint3 tgpig[[threadgroup_position_in_grid]],
63586358
uint tiitg[[thread_index_in_threadgroup]],
63596359
uint3 tptg [[threads_per_threadgroup]]) {
@@ -6373,10 +6373,10 @@ kernel void kernel_get_rows_q(
63736373

63746374
template<typename T>
63756375
kernel void kernel_get_rows_f(
6376+
constant ggml_metal_kargs_get_rows & args,
63766377
device const void * src0,
63776378
device const void * src1,
63786379
device float * dst,
6379-
constant ggml_metal_kargs_get_rows & args,
63806380
uint3 tgpig[[threadgroup_position_in_grid]],
63816381
uint tiitg[[thread_index_in_threadgroup]],
63826382
uint3 tptg [[threads_per_threadgroup]]) {
@@ -6394,10 +6394,10 @@ kernel void kernel_get_rows_f(
63946394
}
63956395

63966396
kernel void kernel_get_rows_i32(
6397+
constant ggml_metal_kargs_get_rows & args,
63976398
device const void * src0,
63986399
device const void * src1,
63996400
device int32_t * dst,
6400-
constant ggml_metal_kargs_get_rows & args,
64016401
uint3 tgpig[[threadgroup_position_in_grid]],
64026402
uint tiitg[[thread_index_in_threadgroup]],
64036403
uint3 tptg [[threads_per_threadgroup]]) {
@@ -6414,6 +6414,36 @@ kernel void kernel_get_rows_i32(
64146414
}
64156415
}
64166416

6417+
template<typename T>
6418+
kernel void kernel_set_rows_f(
6419+
constant ggml_metal_kargs_set_rows & args,
6420+
device const void * src0,
6421+
device const void * src1,
6422+
device float * dst,
6423+
uint3 tgpig[[threadgroup_position_in_grid]],
6424+
uint tiitg[[thread_index_in_threadgroup]],
6425+
uint3 tptg [[threads_per_threadgroup]]) {
6426+
const int32_t i03 = tgpig.z;
6427+
const int32_t i02 = tgpig.y;
6428+
6429+
const int32_t i12 = i03%args.ne12;
6430+
const int32_t i11 = i02%args.ne11;
6431+
6432+
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
6433+
if (i01 >= args.ne01) {
6434+
return;
6435+
}
6436+
6437+
const int32_t i10 = i01;
6438+
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
6439+
6440+
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
6441+
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
6442+
6443+
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
6444+
dst_row[ind] = (T) src_row[ind];
6445+
}
6446+
}
64176447

64186448
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
64196449
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
@@ -6837,6 +6867,18 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get
68376867
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
68386868
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
68396869

6870+
//
6871+
// set rows
6872+
//
6873+
6874+
typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
6875+
6876+
template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
6877+
template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
6878+
#if defined(GGML_METAL_USE_BF16)
6879+
template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
6880+
#endif
6881+
68406882
//
68416883
// matrix-matrix multiplication
68426884
//

0 commit comments

Comments
 (0)