Skip to content

Commit 42b5324

Browse files
committed
ggml : add get_rows
1 parent a977c11 commit 42b5324

File tree

2 files changed

+49
-83
lines changed

2 files changed

+49
-83
lines changed

ggml/src/ggml-metal.m

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,7 @@
5959
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
6060
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
6161
GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
62+
GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
6263
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
6364
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
6465
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
@@ -515,6 +516,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
515516
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
516517
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
517518
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
519+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true);
518520
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
519521
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
520522
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
@@ -736,8 +738,10 @@ static void ggml_metal_free(struct ggml_metal_context * ctx) {
736738

737739
static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
738740
for (size_t i = 0, n = 3; i < n; ++i) {
739-
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
740-
return false;
741+
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16 &&
742+
op->op != GGML_OP_GET_ROWS) {
743+
printf("op = %s, src[%zu] = %s\n", ggml_op_name(op->op), i, ggml_type_name(op->src[i]->type));
744+
GGML_ASSERT(false);
741745
}
742746
}
743747

@@ -837,7 +841,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
837841
case GGML_OP_DIAG_MASK_INF:
838842
case GGML_OP_GET_ROWS:
839843
{
840-
return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
844+
return op->ne[3] == 1;
841845
}
842846
default:
843847
return false;
@@ -2162,6 +2166,7 @@ static enum ggml_status ggml_metal_graph_compute(
21622166
switch (src0->type) {
21632167
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
21642168
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
2169+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
21652170
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
21662171
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
21672172
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;

ggml/src/ggml-metal.metal

Lines changed: 41 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -5730,9 +5730,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
57305730
}
57315731

57325732
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
5733-
kernel void kernel_get_rows(
5733+
kernel void kernel_get_rows_q(
57345734
device const void * src0,
5735-
device const char * src1,
5735+
device const void * src1,
57365736
device float * dst,
57375737
constant int64_t & ne00,
57385738
constant uint64_t & nb01,
@@ -5745,55 +5745,24 @@ kernel void kernel_get_rows(
57455745
uint3 tgpig[[threadgroup_position_in_grid]],
57465746
uint tiitg[[thread_index_in_threadgroup]],
57475747
uint3 tptg [[threads_per_threadgroup]]) {
5748-
//const int64_t i = tgpig;
5749-
//const int64_t r = ((device int32_t *) src1)[i];
5750-
57515748
const int64_t i10 = tgpig.x;
57525749
const int64_t i11 = tgpig.y;
57535750

5754-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5751+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
57555752

57565753
const int64_t i02 = i11;
57575754

57585755
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
57595756
float4x4 temp;
5760-
dequantize_func(
5761-
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
5757+
dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
57625758
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
57635759
}
57645760
}
57655761

5766-
kernel void kernel_get_rows_f32(
5767-
device const void * src0,
5768-
device const char * src1,
5769-
device float * dst,
5770-
constant int64_t & ne00,
5771-
constant uint64_t & nb01,
5772-
constant uint64_t & nb02,
5773-
constant int64_t & ne10,
5774-
constant uint64_t & nb10,
5775-
constant uint64_t & nb11,
5776-
constant uint64_t & nb1,
5777-
constant uint64_t & nb2,
5778-
uint3 tgpig[[threadgroup_position_in_grid]],
5779-
uint tiitg[[thread_index_in_threadgroup]],
5780-
uint3 tptg [[threads_per_threadgroup]]) {
5781-
const int64_t i10 = tgpig.x;
5782-
const int64_t i11 = tgpig.y;
5783-
5784-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5785-
5786-
const int64_t i02 = i11;
5787-
5788-
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5789-
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
5790-
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
5791-
}
5792-
}
5793-
5794-
kernel void kernel_get_rows_f16(
5762+
template<typename T>
5763+
kernel void kernel_get_rows_f(
57955764
device const void * src0,
5796-
device const char * src1,
5765+
device const void * src1,
57975766
device float * dst,
57985767
constant int64_t & ne00,
57995768
constant uint64_t & nb01,
@@ -5809,19 +5778,19 @@ kernel void kernel_get_rows_f16(
58095778
const int64_t i10 = tgpig.x;
58105779
const int64_t i11 = tgpig.y;
58115780

5812-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5781+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
58135782

58145783
const int64_t i02 = i11;
58155784

58165785
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5817-
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
5818-
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
5786+
(( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
5787+
((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
58195788
}
58205789
}
58215790

58225791
kernel void kernel_get_rows_i32(
58235792
device const void * src0,
5824-
device const char * src1,
5793+
device const void * src1,
58255794
device int32_t * dst,
58265795
constant int64_t & ne00,
58275796
constant uint64_t & nb01,
@@ -5837,13 +5806,13 @@ kernel void kernel_get_rows_i32(
58375806
const int64_t i10 = tgpig.x;
58385807
const int64_t i11 = tgpig.y;
58395808

5840-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5809+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
58415810

58425811
const int64_t i02 = i11;
58435812

58445813
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5845-
((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
5846-
((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
5814+
(( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
5815+
((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
58475816
}
58485817
}
58495818

@@ -6237,41 +6206,33 @@ kernel void kernel_mul_mm_id(
62376206
// get rows
62386207
//
62396208

6240-
typedef void (get_rows_t)(
6241-
device const void * src0,
6242-
device const char * src1,
6243-
device float * dst,
6244-
constant int64_t & ne00,
6245-
constant uint64_t & nb01,
6246-
constant uint64_t & nb02,
6247-
constant int64_t & ne10,
6248-
constant uint64_t & nb10,
6249-
constant uint64_t & nb11,
6250-
constant uint64_t & nb1,
6251-
constant uint64_t & nb2,
6252-
uint3, uint, uint3);
6253-
6254-
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
6255-
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
6256-
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
6257-
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
6258-
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
6259-
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
6260-
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
6261-
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
6262-
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
6263-
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
6264-
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
6265-
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
6266-
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6267-
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6268-
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6269-
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
6270-
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
6271-
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
6272-
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
6273-
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6274-
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6209+
typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
6210+
6211+
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
6212+
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
6213+
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
6214+
6215+
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
6216+
6217+
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
6218+
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
6219+
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
6220+
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
6221+
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
6222+
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
6223+
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
6224+
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
6225+
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
6226+
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
6227+
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6228+
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6229+
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6230+
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
6231+
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
6232+
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
6233+
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
6234+
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
6235+
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
62756236

62766237
//
62776238
// matrix-matrix multiplication

0 commit comments

Comments
 (0)