Skip to content

Commit 5983eb1

Browse files
committed
metal : add ggml_set_rows implementation
ggml-ci
1 parent 81bb28f commit 5983eb1

File tree

3 files changed

+115
-8
lines changed

3 files changed

+115
-8
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,21 @@ typedef struct {
521521
uint64_t nb2;
522522
} ggml_metal_kargs_get_rows;
523523

524+
typedef struct {
525+
int32_t nbk0;
526+
uint64_t nb01;
527+
uint64_t nb02;
528+
uint64_t nb03;
529+
int32_t ne11;
530+
int32_t ne12;
531+
uint64_t nb10;
532+
uint64_t nb11;
533+
uint64_t nb12;
534+
uint64_t nb1;
535+
uint64_t nb2;
536+
uint64_t nb3;
537+
} ggml_metal_kargs_set_rows;
538+
524539
typedef struct {
525540
int64_t ne00;
526541
int64_t ne01;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
202202
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
203203
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
204204
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207+
GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
205208
GGML_METAL_KERNEL_TYPE_RMS_NORM,
206209
GGML_METAL_KERNEL_TYPE_L2_NORM,
207210
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1166,6 +1169,9 @@ @implementation GGMLMetalClass
11661169
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
11671170
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
11681171
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
1172+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
1173+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
1174+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
11691175
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
11701176
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
11711177
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1630,7 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16301636

16311637
if (!use_bfloat) {
16321638
for (size_t i = 0, n = 3; i < n; ++i) {
1633-
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
1639+
if (op->src[i] != NULL && (op->src[i]->type == GGML_TYPE_BF16 || op->type == GGML_TYPE_BF16)) {
16341640
return false;
16351641
}
16361642
}
@@ -1798,6 +1804,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
17981804
{
17991805
return op->ne[3] == 1;
18001806
}
1807+
case GGML_OP_SET_ROWS:
1808+
{
1809+
return op->src[0]->type == GGML_TYPE_F32 && ggml_blck_size(op->type) == 1; // tmp
1810+
}
18011811
default:
18021812
return false;
18031813
}
@@ -3757,13 +3767,57 @@ static bool ggml_metal_encode_node(
37573767
};
37583768

37593769
[encoder setComputePipelineState:pipeline];
3760-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3761-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3762-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3763-
[encoder setBytes:&args length:sizeof(args) atIndex:3];
3770+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3771+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3772+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3773+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
37643774

37653775
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
37663776
} break;
3777+
case GGML_OP_SET_ROWS:
3778+
{
3779+
id<MTLComputePipelineState> pipeline = nil;
3780+
3781+
switch (dst->type) {
3782+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
3783+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
3784+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
3785+
default: GGML_ABORT("not implemented");
3786+
}
3787+
3788+
const int32_t nbk0 = ne0/ggml_blck_size(dst->type);
3789+
3790+
int nth = 32; // SIMD width
3791+
3792+
while (nth < nbk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
3793+
nth *= 2;
3794+
}
3795+
3796+
nth = MIN(nth, nbk0);
3797+
3798+
ggml_metal_kargs_set_rows args = {
3799+
/*.nbk0 =*/ nbk0,
3800+
/*.nb01 =*/ nb01,
3801+
/*.nb02 =*/ nb02,
3802+
/*.nb03 =*/ nb03,
3803+
/*.ne11 =*/ ne11,
3804+
/*.ne12 =*/ ne12,
3805+
/*.nb10 =*/ nb10,
3806+
/*.nb11 =*/ nb11,
3807+
/*.nb12 =*/ nb12,
3808+
/*.nb1 =*/ nb1,
3809+
/*.nb2 =*/ nb2,
3810+
/*.nb3 =*/ nb3,
3811+
};
3812+
3813+
[encoder setComputePipelineState:pipeline];
3814+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3815+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3816+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3817+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3818+
3819+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3820+
} break;
37673821
case GGML_OP_RMS_NORM:
37683822
{
37693823
GGML_ASSERT(ne00 % 4 == 0);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6350,10 +6350,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
63506350

63516351
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
63526352
kernel void kernel_get_rows_q(
6353+
constant ggml_metal_kargs_get_rows & args,
63536354
device const void * src0,
63546355
device const void * src1,
63556356
device float * dst,
6356-
constant ggml_metal_kargs_get_rows & args,
63576357
uint3 tgpig[[threadgroup_position_in_grid]],
63586358
uint tiitg[[thread_index_in_threadgroup]],
63596359
uint3 tptg [[threads_per_threadgroup]]) {
@@ -6373,10 +6373,10 @@ kernel void kernel_get_rows_q(
63736373

63746374
template<typename T>
63756375
kernel void kernel_get_rows_f(
6376+
constant ggml_metal_kargs_get_rows & args,
63766377
device const void * src0,
63776378
device const void * src1,
63786379
device float * dst,
6379-
constant ggml_metal_kargs_get_rows & args,
63806380
uint3 tgpig[[threadgroup_position_in_grid]],
63816381
uint tiitg[[thread_index_in_threadgroup]],
63826382
uint3 tptg [[threads_per_threadgroup]]) {
@@ -6394,10 +6394,10 @@ kernel void kernel_get_rows_f(
63946394
}
63956395

63966396
kernel void kernel_get_rows_i32(
6397+
constant ggml_metal_kargs_get_rows & args,
63976398
device const void * src0,
63986399
device const void * src1,
63996400
device int32_t * dst,
6400-
constant ggml_metal_kargs_get_rows & args,
64016401
uint3 tgpig[[threadgroup_position_in_grid]],
64026402
uint tiitg[[thread_index_in_threadgroup]],
64036403
uint3 tptg [[threads_per_threadgroup]]) {
@@ -6414,6 +6414,32 @@ kernel void kernel_get_rows_i32(
64146414
}
64156415
}
64166416

6417+
template<typename T>
6418+
kernel void kernel_set_rows_f(
6419+
constant ggml_metal_kargs_set_rows & args,
6420+
device const void * src0,
6421+
device const void * src1,
6422+
device float * dst,
6423+
uint3 tgpig[[threadgroup_position_in_grid]],
6424+
uint tiitg[[thread_index_in_threadgroup]],
6425+
uint3 tptg [[threads_per_threadgroup]]) {
6426+
const int32_t i03 = tgpig.z;
6427+
const int32_t i02 = tgpig.y;
6428+
const int32_t i01 = tgpig.x;
6429+
6430+
const int32_t i12 = i03%args.ne12;
6431+
const int32_t i11 = i02%args.ne11;
6432+
const int32_t i10 = i01;
6433+
6434+
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
6435+
6436+
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
6437+
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
6438+
6439+
for (int ind = tiitg; ind < args.nbk0; ind += tptg.x) {
6440+
dst_row[ind] = (T) src_row[ind];
6441+
}
6442+
}
64176443

64186444
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
64196445
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
@@ -6837,6 +6863,18 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get
68376863
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
68386864
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
68396865

6866+
//
6867+
// set rows
6868+
//
6869+
6870+
typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
6871+
6872+
template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
6873+
template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
6874+
#if defined(GGML_METAL_USE_BF16)
6875+
template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
6876+
#endif
6877+
68406878
//
68416879
// matrix-matrix multiplication
68426880
//

0 commit comments

Comments
 (0)