
Commit 69e4a32

Merge commit 'd4e0d95cf581f50c9a21d06eaecae2dd580076bd' into concedo_experimental
# Conflicts:
#   .github/workflows/build.yml
#   common/CMakeLists.txt
#   ggml/src/CMakeLists.txt
#   ggml/src/ggml-opencl/CMakeLists.txt
#   ggml/src/ggml-opencl/ggml-opencl.cpp
#   ggml/src/ggml-rpc/ggml-rpc.cpp
#   scripts/sync-ggml.last
#   tests/CMakeLists.txt
2 parents 33809c9 + d4e0d95 commit 69e4a32

18 files changed: +865 additions, -533 deletions

convert_hf_to_gguf.py

Lines changed: 5 additions & 29 deletions
@@ -556,8 +556,11 @@ def set_gguf_parameters(self):
             logger.info(f"gguf: experts used count = {n_experts_used}")
 
         if (head_dim := self.hparams.get("head_dim")) is not None:
-            self.gguf_writer.add_key_length(head_dim)
-            self.gguf_writer.add_value_length(head_dim)
+            # Workaround for incorrect AutoConfig value for DeepSeekV3 (is set correctly in DeepSeekV2Model class)
+            # https://github.com/huggingface/transformers/blob/19224c3642705c5b6988c9f5f4251f83323d05ae/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py#L210
+            if self.hparams.get("model_type") != "deepseek_v3":
+                self.gguf_writer.add_key_length(head_dim)
+                self.gguf_writer.add_value_length(head_dim)
 
         self.gguf_writer.add_file_type(self.ftype)
         logger.info(f"gguf: file type = {self.ftype}")
@@ -4798,25 +4801,6 @@ def prepare_tensors(self):
 class JinaBertV2Model(BertModel):
     model_arch = gguf.MODEL_ARCH.JINA_BERT_V2
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.intermediate_size = self.hparams["intermediate_size"]
-
-    def get_tensors(self):
-        for name, data in super().get_tensors():
-            if 'gated_layer' in name:
-                d1 = data[:self.intermediate_size, :]
-                name1 = name.replace('gated_layers', 'gated_layers_w')
-                name1 = name1.replace('up_gated_layer', 'gated_layers_v')
-                d2 = data[self.intermediate_size:, :]
-                name2 = name.replace('gated_layers', 'gated_layers_v')
-                name2 = name2.replace('up_gated_layer', 'gated_layers_w')
-                yield name1, d1
-                yield name2, d2
-                continue
-
-            yield name, data
-
     def set_vocab(self):
         tokenizer_class = 'BertTokenizer'
         with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
@@ -4832,14 +4816,6 @@ def set_vocab(self):
         self.gguf_writer.add_add_bos_token(True)
         self.gguf_writer.add_add_eos_token(True)
 
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # if name starts with "bert.", remove the prefix
-        # e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
-        if name.startswith("bert."):
-            name = name[5:]
-
-        return super().modify_tensors(data_torch, name, bid)
-
 
 @ModelBase.register("OpenELMForCausalLM")
 class OpenELMModel(TextModel):

ggml/src/ggml-cpu/ggml-cpu-impl.h

Lines changed: 4 additions & 1 deletion
@@ -518,11 +518,14 @@ void ggml_barrier(struct ggml_threadpool * tp);
 #elif defined(__GNUC__)
 // GCC/Clang on *nix
 #    define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(weak name = alias) // NOLINT
-#elif defined(_MSC_VER) && defined (_WIN64)
+#elif defined(_MSC_VER) && defined(_WIN64)
 // MSVC
 // Note: C name mangling varies across different calling conventions
 // see https://learn.microsoft.com/en-us/cpp/build/reference/decorated-names?view=msvc-170
 #    define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(comment(linker, "/alternatename:" #name "=" #alias))
+#elif defined(_MSC_VER) && defined(WIN32)
+// ref: https://github.com/ggml-org/whisper.cpp/pull/3239#issuecomment-2958224591
+#    define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(comment(linker, "/alternatename:_" #name "=_" #alias))
 #else
 #    error "Unsupported compiler for GGML_WEAK_ALIAS"
 #endif
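The new WIN32 branch differs from the x64 branch only in C name decoration: 64-bit MSVC leaves cdecl symbols undecorated, while 32-bit Windows prefixes them with an underscore, so the /alternatename mapping has to carry that prefix. A small, purely illustrative Python sketch of the rule (symbol names invented):

    def alternatename_arg(name: str, alias: str, win64: bool) -> str:
        # x64 MSVC: undecorated C symbols; 32-bit Windows: leading underscore,
        # hence the extra '_' in the WIN32 branch of GGML_WEAK_ALIAS above.
        prefix = "" if win64 else "_"
        return f"/alternatename:{prefix}{name}={prefix}{alias}"

    # alternatename_arg("ggml_foo", "ggml_foo_generic", win64=False)
    #   -> '/alternatename:_ggml_foo=_ggml_foo_generic'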

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 54 additions & 59 deletions
@@ -3333,8 +3333,6 @@ kernel void kernel_flash_attn_ext(
 
     threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 + 0*DK); // holds the query data
     threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 + 0*DK); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
     threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
 
     threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
@@ -3548,20 +3546,20 @@ kernel void kernel_flash_attn_ext(
 
             // O = diag(ms)*O
             {
-                s8x8_t mm;
-                simdgroup_load(mm, ss + 2*C, TS, 0, false);
+                s8x8_t ms;
+                simdgroup_load(ms, ss + 2*C, TS, 0, false);
 
                 #pragma unroll(DV8)
                 for (short i = 0; i < DV8; ++i) {
-                    simdgroup_multiply(lo[i], mm, lo[i]);
+                    simdgroup_multiply(lo[i], ms, lo[i]);
                 }
             }
 
             // O = O + (Q*K^T)*V
             {
                 for (short cc = 0; cc < C/8; ++cc) {
-                    s8x8_t ms;
-                    simdgroup_load(ms, ss + 8*cc, TS, 0, false);
+                    s8x8_t vs;
+                    simdgroup_load(vs, ss + 8*cc, TS, 0, false);
 
                     if (is_same<vd4x4_t, v4x4_t>::value) {
                         // we can read directly from global memory
@@ -3572,7 +3570,7 @@ kernel void kernel_flash_attn_ext(
                             v8x8_t mv;
                             simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
 
-                            simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
+                            simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
                         }
                     } else {
                         for (short ii = 0; ii < DV16; ii += 4) {
@@ -3593,10 +3591,10 @@ kernel void kernel_flash_attn_ext(
                                 v8x8_t mv;
 
                                 simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
-                                simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+                                simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
 
                                 simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
-                                simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
                             }
                         } else {
                             if (ii + tx < DV16) {
@@ -3611,10 +3609,10 @@ kernel void kernel_flash_attn_ext(
                                 v8x8_t mv;
 
                                 simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
-                                simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+                                simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
 
                                 simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
-                                simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
                             }
                         }
                     }
@@ -3624,83 +3622,80 @@ kernel void kernel_flash_attn_ext(
             }
         }
 
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-        for (short j = 0; j < Q; ++j) {
-            if (tiisg == 0) {
-                ss[j*TS + 0] = S[j];
-                ss[j*TS + 1] = M[j];
-            }
+        for (short j = tiisg; j < Q; j += NW) {
+            ss[j*TS + 0] = S[j];
+            ss[j*TS + 1] = M[j];
         }
     }
 
-    // reduce the warps sequentially
-    for (ushort sg = 1; sg < nsg; ++sg) {
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        // each simdgroup stores its output to shared memory, reusing sq
-        if (sgitg == sg) {
-            for (short i = 0; i < DV8; ++i) {
-                simdgroup_store(lo[i], so + i*8, DV, 0, false);
-            }
+    threadgroup float  * so  = (threadgroup float  *) (shmem_f16 + 0*DK); // reuse query data for accumulation
+    threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
+
+    // store result to shared memory in F32
+    if (sgitg == 0) {
+        for (short i = 0; i < DV8; ++i) {
+            //simdgroup_store(lo[i], so + i*8, DV, 0, false);
+            simdgroup_float8x8 t(1.0f);
+            simdgroup_multiply(t, lo[i], t);
+            simdgroup_store(t, so + i*8, DV, 0, false);
         }
+    }
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        // the first simdgroup accumulates the results from the other simdgroups
-        if (sgitg == 0) {
-            for (short j = 0; j < Q; ++j) {
-                const float S0 = ss[j*TS +         0];
-                const float S1 = ss[j*TS + sg*SH + 0];
+    // reduce the warps sequentially
+    for (ushort sg = 1; sg < nsg; ++sg) {
+        if (sgitg == sg) {
+            for (short j = tiisg; j < Q; j += NW) {
+                const float S0 = ss[j*TS - 1*SH + 0];
+                const float S1 = ss[j*TS        + 0];
 
-                const float M0 = ss[j*TS +         1];
-                const float M1 = ss[j*TS + sg*SH + 1];
+                const float M0 = ss[j*TS - 1*SH + 1];
+                const float M1 = ss[j*TS        + 1];
 
                 const float M = max(M0, M1);
 
-                const float ms0 = exp(M0 - M);
-                const float ms1 = exp(M1 - M);
+                float ms0 = exp(M0 - M);
+                float ms1 = exp(M1 - M);
 
                 const float S = S0*ms0 + S1*ms1;
 
-                if (tiisg == 0) {
-                    ss[j*TS + 0] = S;
-                    ss[j*TS + 1] = M;
+                ss[j*TS + 0] = S;
+                ss[j*TS + 1] = M;
 
-                    ss[j*TS + 2*C + j        ] = ms0;
-                    ss[j*TS + 2*C + j + sg*SH] = ms1;
-                }
+                ss[j*TS + 2*C + j - 1*SH] = ms0;
+                ss[j*TS + 2*C + j       ] = ms1;
             }
 
+            //simdgroup_barrier(mem_flags::mem_threadgroup);
+
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
             {
                 s8x8_t ms0;
                 s8x8_t ms1;
 
-                simdgroup_load(ms0, ss + 2*C,         TS, 0, false);
-                simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
+                simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
+                simdgroup_load(ms1, ss + 2*C,        TS, 0, false);
 
                 #pragma unroll(DV8)
                 for (short i = 0; i < DV8; ++i) {
-                    o8x8_t t;
+                    simdgroup_float8x8 t;
 
                     simdgroup_load (t, so + i*8, DV, 0, false);
-                    simdgroup_multiply(t, ms1, t);
+                    simdgroup_multiply(t, ms0, t);
 
-                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+                    simdgroup_multiply_accumulate(t, ms1, lo[i], t);
+                    simdgroup_store(t, so + i*8, DV, 0, false);
                 }
             }
         }
-    }
 
-    // store result to shared memory (reuse sq)
-    if (sgitg == 0) {
-        for (short i = 0; i < DV8; ++i) {
-            simdgroup_store(lo[i], so + i*8, DV, 0, false);
-        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);
+    threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
 
     // final rescale with 1/S and store to global memory
     for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
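The rewritten reduction above merges each simdgroup's partial attention state into the running one using the streaming-softmax identities visible in the hunk: M = max(M0, M1), ms_i = exp(M_i - M), S = S0*ms0 + S1*ms1, and O = ms0*O0 + ms1*O1, with the final rescale by 1/S applied afterwards. A small NumPy sketch of that merge step, illustrative only and not the Metal code:

    import numpy as np

    def merge_partials(O0, S0, M0, O1, S1, M1):
        # O*: unnormalized outputs (Q, DV); S*: running exp-weight sums (Q,);
        # M*: running max logits (Q,). Mirrors the ms0/ms1 rescaling above.
        M = np.maximum(M0, M1)
        ms0 = np.exp(M0 - M)
        ms1 = np.exp(M1 - M)
        S = S0 * ms0 + S1 * ms1
        O = O0 * ms0[:, None] + O1 * ms1[:, None]
        return O, S, M

    # The kernel's last step then divides by S ("final rescale with 1/S").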
@@ -3723,17 +3718,17 @@ kernel void kernel_flash_attn_ext(
     half,  half4x4,        simdgroup_half8x8,  \
     float,                 simdgroup_float8x8, \
     float,                 simdgroup_float8x8, \
-    float, float4,         simdgroup_float8x8
-    //half,  half4,          simdgroup_half8x8
+    half,  half4,          simdgroup_half8x8
+    //float, float4,         simdgroup_float8x8
 
 #define FA_TYPES_BF \
     bfloat, bfloat4,   simdgroup_bfloat8x8, \
     bfloat, bfloat4x4, simdgroup_bfloat8x8, \
     bfloat, bfloat4x4, simdgroup_bfloat8x8, \
     float,             simdgroup_float8x8,  \
     float,             simdgroup_float8x8,  \
-    float,  float4,    simdgroup_float8x8
-    //half,   half4,     simdgroup_half8x8
+    half,   half4,     simdgroup_half8x8
+    //float,  float4,    simdgroup_float8x8
 
 typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
