Commit 19116a4

Merge branch 'master' into gg/clip-fa

2 parents a4b54f2 + 2f68ce7

155 files changed (+16348, -14480 lines)

.github/workflows/docker.yml

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ jobs:
           # https://github.com/ggml-org/llama.cpp/issues/11888
           #- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false }
           - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
-          - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
+          - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
           - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
           - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
           - { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }

CODEOWNERS

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@
 /src/llama-model-loader.* @slaren
 /src/llama-model.* @CISC
 /src/llama-vocab.* @CISC
+/src/models/ @CISC
 /tests/ @ggerganov
 /tests/test-backend-ops.cpp @slaren
 /tests/test-thread-safety.cpp @slaren

common/arg.cpp

Lines changed: 1 addition & 1 deletion
@@ -2030,7 +2030,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                 params.system_prompt.pop_back();
             }
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION}));
     add_opt(common_arg(
         {"--in-file"}, "FNAME",
         "an input file (repeat to specify multiple files)",

convert_hf_to_gguf.py

Lines changed: 61 additions & 0 deletions
@@ -1054,6 +1054,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
             # ref: https://huggingface.co/ibm-granite/granite-docling-258M
             res = "granite-docling"
+        if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95":
+            # ref: https://huggingface.co/MiniMaxAI/MiniMax-M2
+            res = "minimax-m2"

         if res is None:
             logger.warning("\n")
@@ -7126,6 +7129,64 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")


+@ModelBase.register("MiniMaxM2ForCausalLM")
+class MiniMaxM2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.MINIMAXM2
+    _experts_cache: dict[int, dict[str, Tensor]] = {}
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hparams["num_experts"] = self.hparams["num_local_experts"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if self.hparams["scoring_func"] == "sigmoid":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        elif self.hparams["scoring_func"] == "softmax":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
+
+        self.gguf_writer.add_expert_feed_forward_length(self.find_hparam(["intermediate_size"]))
+        self.gguf_writer.add_rope_dimension_count(self.find_hparam(["rotary_dim"]))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        # merge expert weights
+        if 'experts' in name:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            expert_cache = self._experts_cache.setdefault(bid, {})
+            expert_cache[name] = data_torch
+            expert_weights = ["w1", "w2", "w3"]
+
+            # not enough expert weights to merge
+            if len(expert_cache) < n_experts * len(expert_weights):
+                return []
+
+            tensors: list[tuple[str, Tensor]] = []
+            for w_name in expert_weights:
+                datas: list[Tensor] = []
+
+                for xid in range(n_experts):
+                    ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
+                    datas.append(expert_cache[ename])
+                    del expert_cache[ename]
+
+                data_torch = torch.stack(datas, dim=0)
+                merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
+                new_name = self.map_tensor_name(merged_name)
+                tensors.append((new_name, data_torch))
+
+            del self._experts_cache[bid]
+            return tensors
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Dots1ForCausalLM")
 class Dots1Model(Qwen2MoeModel):
     model_arch = gguf.MODEL_ARCH.DOTS1
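
For illustration, a minimal sketch of the merge performed by `MiniMaxM2Model.modify_tensors` once every expert tensor of a layer has been collected: the per-expert 2D weights for each of `w1`/`w2`/`w3` are stacked along a new leading dimension into a single 3D tensor per projection. The shapes, the expert count, and the random data below are placeholders; only the tensor-name pattern is taken from the converter above.

```python
# Toy sketch of the expert-merge step (shapes and expert count are made up).
import torch

bid = 0                                   # layer index
n_experts = 4                             # placeholder; the real model has many more experts
expert_weights = ["w1", "w2", "w3"]

# stand-in for the converter's per-layer expert cache: one 2D weight per (expert, projection)
expert_cache = {
    f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight": torch.randn(16, 32)
    for xid in range(n_experts)
    for w_name in expert_weights
}

for w_name in expert_weights:
    datas = [
        expert_cache[f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"]
        for xid in range(n_experts)
    ]
    merged = torch.stack(datas, dim=0)    # -> (n_experts, 16, 32), one tensor per projection
    merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
    print(merged_name, tuple(merged.shape))
```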

convert_hf_to_gguf_update.py

Lines changed: 2 additions & 1 deletion
@@ -141,6 +141,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
     {"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
     {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
+    {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
 ]

 # some models are known to be broken upstream, so we will skip them as exceptions
@@ -435,7 +436,7 @@ def get_vocab_base_pre(self, tokenizer) -> str:
             tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
         else:
             tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
-    except OSError as e:
+    except (OSError, TypeError) as e:
        logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
        continue  # Skip this model and continue with the next one in the loop

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 0 additions & 5 deletions
@@ -1613,13 +1613,8 @@ static void ggml_compute_forward_mul_mat_id(
         chunk_size = 64;
     }

-#if defined(__aarch64__)
-    // disable for ARM
-    const bool disable_chunking = true;
-#else
     // disable for NUMA
     const bool disable_chunking = ggml_is_numa();
-#endif // defined(__aarch64__)

     int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
     int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;

ggml/src/ggml-cpu/repack.cpp

Lines changed: 57 additions & 21 deletions
@@ -1600,6 +1600,32 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         return false;
     }

+    void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
+        const ggml_tensor * src0 = op->src[0];
+        const ggml_tensor * src1 = op->src[1];
+        ggml_tensor * dst = op;
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const void * src1_wdata = params->wdata;
+        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
+
+        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+        if (ne11 > 3) {
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                (float *) ((char *) dst->data) + src0_start, ne01,
+                (const char *) src0->data + src0_start * nb01,
+                (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        }
+        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                (const char *) src0->data + src0_start * nb01,
+                (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                src0_end - src0_start);
+        }
+    }
+
     void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
         const ggml_tensor * src0 = op->src[0];
         const ggml_tensor * src1 = op->src[1];
@@ -1643,31 +1669,41 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
             from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
         }

-        ggml_barrier(params->threadpool);
+        // disable for NUMA
+        const bool disable_chunking = ggml_is_numa();

-        const void * src1_wdata = params->wdata;
-        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
-        int64_t src0_start = (ith * ne01) / nth;
-        int64_t src0_end = ((ith + 1) * ne01) / nth;
-        src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
-        src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
-        if (src0_start >= src0_end) {
-            return;
+        // 4x chunks per thread
+        int64_t nr = ggml_nrows(op->src[0]);
+        int nth_scaled = nth * 4;
+        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+        int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+        if (nth == 1 || nchunk < nth || disable_chunking) {
+            nchunk = nth;
         }

-        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
-        if (ne11 > 3) {
-            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                (float *) ((char *) dst->data) + src0_start, ne01,
-                (const char *) src0->data + src0_start * nb01,
-                (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        if (ith == 0) {
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
+            ggml_threadpool_chunk_set(params->threadpool, nth);
         }
-        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
-            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
-                (const char *) src0->data + src0_start * nb01,
-                (const char *) src1_wdata + (src1_col_stride * iter), 1,
-                src0_end - src0_start);
+
+        ggml_barrier(params->threadpool);
+
+        // The first chunk comes from our thread_id, the rest will get auto-assigned.
+        int current_chunk = ith;
+
+        while (current_chunk < nchunk) {
+            int64_t src0_start = (current_chunk * ne01) / nchunk;
+            int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
+            src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+            src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+            if (src0_start >= src0_end) {
+                break;
+            }
+
+            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
+
+            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
         }
     }
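
To see what the new chunked split does to the row ranges, here is a rough Python sketch of the same arithmetic: roughly four chunks per thread, with both chunk boundaries rounded up to a multiple of NB_COLS so each chunk stays aligned to the repacked column blocks. In the diff these chunk indices are then distributed across threads through `ggml_threadpool_chunk_set`/`ggml_threadpool_chunk_add`; the concrete values below are arbitrary stand-ins, not taken from a real tensor.

```python
# Rough sketch of the chunk partitioning used by the new forward_mul_mat path.
# NB_COLS, ne01 and nth are placeholder values.
NB_COLS = 8          # repacked column block size
ne01    = 1024       # rows of src0 to split
nth     = 4          # worker threads

chunk_size = (ne01 + nth * 4 - 1) // (nth * 4)   # ~4 chunks per thread
nchunk = (ne01 + chunk_size - 1) // chunk_size
if nth == 1 or nchunk < nth:
    nchunk = nth

def align_up(x: int, a: int) -> int:
    return x + a - (x % a) if x % a else x

for chunk in range(nchunk):
    src0_start = align_up(chunk * ne01 // nchunk, NB_COLS)
    src0_end   = align_up((chunk + 1) * ne01 // nchunk, NB_COLS)
    if src0_start >= src0_end:
        continue
    # forward_mul_mat_one_chunk() would process rows [src0_start, src0_end) here
    print(f"chunk {chunk:2d}: rows [{src0_start}, {src0_end})")
```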

ggml/src/ggml-cuda/common.cuh

Lines changed: 9 additions & 1 deletion
@@ -224,6 +224,11 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

+// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#define VOLTA_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define TURING_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -278,7 +283,10 @@ static bool amd_mfma_available(const int cc) {
 #endif //!defined(GGML_HIP_NO_MMQ_MFMA)
 }

-// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
+static bool volta_mma_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
+}
+
 static bool turing_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 26 additions & 0 deletions
@@ -27,6 +27,7 @@
 #include "ggml-cuda/mmq.cuh"
 #include "ggml-cuda/mmvf.cuh"
 #include "ggml-cuda/mmvq.cuh"
+#include "ggml-cuda/moe-expert-reduce.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
 #include "ggml-cuda/opt-step-sgd.cuh"
@@ -3169,6 +3170,31 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             continue;
         }

+        if (node->op == GGML_OP_MUL) {
+            int current_node = i + 1;
+            int num_views = 0;
+            int num_adds = 0;
+            while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
+                num_views++;
+                current_node++;
+            }
+
+            while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_ADD &&
+                   num_adds < num_views - 1) {
+                num_adds++;
+                current_node++;
+            }
+
+            if (num_adds == num_views - 1 && num_views > 0) {
+                ggml_tensor * dst_node = cgraph->nodes[current_node - 1];
+                if (ggml_cuda_should_use_moe_expert_reduce(cgraph, i, current_node)) {
+                    ggml_cuda_op_moe_expert_reduce(*cuda_ctx, node->src[0], node->src[1], dst_node);
+                    i += num_views + num_adds;
+                    continue;
+                }
+            }
+        }
+
         if (node->op == GGML_OP_ADD) {
             int n_fuse = 0;
             ggml_op ops[8];
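
The shape this scan looks for is a MUL node followed by a run of VIEW nodes and then `num_views - 1` ADD nodes, which is how a MoE expert reduction appears in the graph before fusion. A toy Python rendering of that counting logic over a made-up op list (the list itself is illustrative; only the structure mirrors the C++ above):

```python
# Toy version of the MUL -> VIEW... -> ADD... pattern scan (op list is made up).
ops = ["MUL", "VIEW", "VIEW", "VIEW", "ADD", "ADD", "SOFT_MAX"]

i = 0
if ops[i] == "MUL":
    current = i + 1
    num_views = 0
    while current < len(ops) and ops[current] == "VIEW":
        num_views += 1
        current += 1

    num_adds = 0
    while current < len(ops) and ops[current] == "ADD" and num_adds < num_views - 1:
        num_adds += 1
        current += 1

    if num_views > 0 and num_adds == num_views - 1:
        # in the CUDA backend, ggml_cuda_should_use_moe_expert_reduce() validates the
        # candidate nodes at this point before launching the fused kernel
        print(f"candidate fusion: {num_views} views + {num_adds} adds; skip {num_views + num_adds} nodes")
```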
