
Commit 4d1783a

cont
1 parent 5e09948 commit 4d1783a

File tree: 1 file changed, +20 -10 lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 20 additions & 10 deletions
@@ -12,13 +12,9 @@ __embed_ggml-common.h__
 #define GGML_METAL_USE_METAL4
 
 #ifdef GGML_METAL_USE_METAL4
-#include <metal_stdlib>
 #include <metal_tensor>
 
 #include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
-
-using namespace metal;
-
-using namespace mpp::tensor_ops;
 #endif
 
 using namespace metal;
@@ -1754,7 +1750,7 @@ kernel void kernel_op_sum_f32(
 
     float sumf = 0;
 
-    for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
+    for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
         sumf += src0[i0];
     }
 
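A note on the hunk above: `tpitg.x` and `ntg.x` come from Metal's unsigned thread-position attributes, so a signed `int64_t` counter mixes signedness in the initialization, the `i0 < args.np` bound (assuming `args.np` is unsigned, which this diff does not show), and the stride; `uint64_t` keeps the loop uniformly unsigned. A minimal stand-alone sketch of the same striding pattern, with an illustrative kernel name and buffer layout:

#include <metal_stdlib>
using namespace metal;

// Each thread strides over the buffer by the threadgroup size and keeps a
// private partial sum; the cross-threadgroup reduction is omitted here.
kernel void sum_f32_sketch(
        device const float * src0 [[buffer(0)]],
        device float       * dst  [[buffer(1)]],
        constant uint64_t  & np   [[buffer(2)]],
        uint tpitg [[thread_position_in_threadgroup]],
        uint ntg   [[threads_per_threadgroup]]) {
    float sumf = 0;

    // uint64_t matches the unsigned tpitg/ntg values, so no sign conversions
    for (uint64_t i0 = tpitg; i0 < np; i0 += ntg) {
        sumf += src0[i0];
    }

    dst[tpitg] = sumf; // one partial sum per thread
}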
@@ -5457,6 +5453,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_at
 
 #undef FA_TYPES
 #undef FA_TYPES_BF
+#undef FA_TYPES_F32
 
 constant bool FC_flash_attn_ext_vec_has_mask  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]];
 constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
@@ -6078,6 +6075,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flas
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
 
 #undef FA_TYPES
+#undef FA_TYPES_F32
 
 constant int32_t FC_flash_attn_ext_vec_reduce_DV  [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]];
 constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]];
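Both `#undef FA_TYPES_F32` additions extend the file's define / instantiate / undef convention: a helper macro carries a pack of template arguments through a block of `[[host_name(...)]]` instantiations and is then retired so a later block can redefine it. A schematic of the pattern with a trivial stand-in kernel; every name and the macro body here are illustrative, not taken from the file:

#include <metal_stdlib>
using namespace metal;

template <typename T, int N>
kernel void kernel_fill(
        device T * dst [[buffer(0)]],
        uint tpig [[thread_position_in_grid]]) {
    dst[tpig] = T(N); // trivial body; the pattern is what matters
}

typedef decltype(kernel_fill<float, 8>) kernel_fill_t;

#define FILL_ARGS float, 8 // argument pack shared by the instantiation block

template [[host_name("kernel_fill_f32_n8")]] kernel kernel_fill_t kernel_fill<FILL_ARGS>;

#undef FILL_ARGS // retire the macro so the next block can redefine it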
@@ -8211,9 +8209,9 @@ kernel void kernel_mul_mm(
     auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
     auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
 
-    constexpr auto desc = matmul2d_descriptor(NR1, NR0, NK, false, true, false, matmul2d_descriptor::mode::multiply_accumulate);
+    constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);
 
-    matmul2d<desc, execution_simdgroups<4>> mm;
+    mpp::tensor_ops::matmul2d<desc, execution_simdgroups<4>> mm;
 
     auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
 #endif
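The `mpp::tensor_ops::` qualifications here compensate for the first hunk, which dropped `using namespace mpp::tensor_ops;` from the `GGML_METAL_USE_METAL4` block; only the file-scope `using namespace metal;` remains. A lighter-weight alternative, sketched below rather than taken from the file, would be block-scope using-declarations inside the kernel, importing just the two names that are needed (NR0, NR1, and NK are the surrounding kernel's compile-time constants):

// hypothetical alternative, inside kernel_mul_mm under GGML_METAL_USE_METAL4
using mpp::tensor_ops::matmul2d;
using mpp::tensor_ops::matmul2d_descriptor;

constexpr auto desc = matmul2d_descriptor(
        NR1, NR0, NK, false, true, false,
        matmul2d_descriptor::mode::multiply_accumulate);

matmul2d<desc, execution_simdgroups<4>> mm;

Either spelling is equivalent; the commit's choice of full qualification keeps the conditional block free of using-directives entirely.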
@@ -8359,16 +8357,28 @@ kernel void kernel_mul_mm(
             }
         }
 
-        for (short i = 0; i < 8; ++i) {
+        if (FC_mul_mm_bc_inp) {
+            for (short i = 0; i < 8; ++i) {
+                const short sx = (tiitg%NL1);
+                const short sy = (tiitg/NL1)/8;
+
+                const short lx = i;
+                const short ly = (tiitg/NL1)%8;
+                //const short lx = (tiitg/NL1)%8;
+                //const short ly = i;
+
+                *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
+            }
+        } else {
             const short sx = (tiitg%NL1);
             const short sy = (tiitg/NL1)/8;
 
-            const short lx = i;
+            //const short lx = i;
             const short ly = (tiitg/NL1)%8;
             //const short lx = (tiitg/NL1)%8;
             //const short ly = i;
 
-            *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
+            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
         }
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
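A note on the restructuring above: `FC_mul_mm_bc_inp` is a function constant, so the branch is resolved when the pipeline is specialized and the untaken path is stripped, costing nothing per thread. With bounds-checking on, eight scalars are copied one at a time with zero-padding past `args.ne00`; with it off, the same eight contiguous values move as one vectorized load/store through the `T1_2x4`/`S1_2x4` aliases. A stand-alone sketch of the two paths; the types and tile constants are illustrative assumptions (`S1` = `T1` = `float`, the 2x4 aliases taken as `float2x4`, `NK` = 32, `NL1` = 4), and the flag arrives as a runtime bool only to keep the sketch self-contained:

#include <metal_stdlib>
using namespace metal;

kernel void copy_paths_sketch(
        device const float * y       [[buffer(0)]], // this thread's 8 source values
        constant int       & n_valid [[buffer(1)]], // how many of the 8 are in bounds
        constant bool      & bc_inp  [[buffer(2)]], // stand-in for FC_mul_mm_bc_inp
        threadgroup float  * sb      [[threadgroup(0)]],
        ushort tiitg [[thread_index_in_threadgroup]]) {
    const short NK  = 32; // illustrative row stride of the threadgroup tile
    const short NL1 = 4;  // illustrative thread-to-column grouping

    const short sx = tiitg % NL1;
    const short sy = (tiitg / NL1) / 8;
    const short ly = (tiitg / NL1) % 8;

    threadgroup float * dst = sb + NK*(8*sy + ly) + 8*sx;

    if (bc_inp) {
        // checked path: element-wise copy, zero-padded past the valid range
        for (short lx = 0; lx < 8; ++lx) {
            dst[lx] = lx < n_valid ? y[lx] : 0.0f;
        }
    } else {
        // fast path: all 8 values in a single 2x4 vector load/store
        *(threadgroup float2x4 *) dst = *(device const float2x4 *) y;
    }
}

In the real kernel the source pointer `y` has already been advanced to the thread's tile, and the `(S1_2x4)` conversion handles the `T1` to `S1` type change; with both fixed to `float` here, the store is a plain copy.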
