Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2437,7 +2437,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--cpu-moe", "-cmoe"},
"keep all Mixture of Experts (MoE) weights in the CPU",
[](common_params & params) {
params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_(ch|)exps", ggml_backend_cpu_buffer_type()});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The regexs are now in common.h

}
).set_env("LLAMA_ARG_CPU_MOE"));
add_opt(common_arg(
Expand All @@ -2450,7 +2450,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
for (int i = 0; i < value; ++i) {
// keep strings alive and avoid leaking memory by storing them in a static vector
static std::list<std::string> buft_overrides;
buft_overrides.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
buft_overrides.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_(ch|)exps", i));
params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), ggml_backend_cpu_buffer_type()});
}
}
Expand Down
115 changes: 115 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7820,6 +7820,121 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
class GroveMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GROVEMOE

def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L299
self.gguf_writer.add_expert_chunk_feed_forward_length(self.hparams.get("head_dim") or 128)
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L298
self.gguf_writer.add_experts_per_group(2)
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L376
self.gguf_writer.add_expert_group_scale(0.05)
# YaRN is not enabled by default
# To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])

_experts: list[dict[str, Tensor]] | None = None
_chunk_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.endswith(".expert_bias"):
# FIXME?: Unused https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L303
return []

# process the experts separately
if name.find("chunk_experts") != -1:
n_experts = self.hparams["num_experts"] // 2 # see add_experts_per_group
assert bid is not None

if self._chunk_experts is None:
self._chunk_experts = [{} for _ in range(self.block_count)]

self._chunk_experts[bid][name] = data_torch

if len(self._chunk_experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.chunk_experts.{xid}.{w_name}.weight"
datas.append(self._chunk_experts[bid][ename])
del self._chunk_experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.chunk_experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []
elif name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
super().prepare_tensors()

if self._chunk_experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
chunk_experts = [k for d in self._chunk_experts for k in d.keys()]
if len(chunk_experts) > 0:
raise ValueError(f"Unprocessed adjugate experts: {chunk_experts}")

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("ChameleonForConditionalGeneration")
@ModelBase.register("ChameleonForCausalLM") # obsolete
class ChameleonModel(TextModel):
Expand Down
31 changes: 31 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class LLM:
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length"
EXPERT_CHUNK_FEED_FORWARD_LENGTH = "{arch}.expert_chunk_feed_forward_length"
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
EXPERT_COUNT = "{arch}.expert_count"
Expand All @@ -104,6 +105,8 @@ class LLM:
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
EXPERT_GROUP_SCALE = "{arch}.expert_group_scale"
EXPERTS_PER_GROUP = "{arch}.experts_per_group"
MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers"
POOLING_TYPE = "{arch}.pooling_type"
Expand Down Expand Up @@ -392,6 +395,7 @@ class MODEL_ARCH(IntEnum):
SMALLTHINKER = auto()
LLADA = auto()
SEED_OSS = auto()
GROVEMOE = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
Expand Down Expand Up @@ -441,6 +445,9 @@ class MODEL_TENSOR(IntEnum):
FFN_GATE_SHEXP = auto()
FFN_DOWN_SHEXP = auto()
FFN_UP_SHEXP = auto()
FFN_GATE_CHEXP = auto()
FFN_DOWN_CHEXP = auto()
FFN_UP_CHEXP = auto()
FFN_EXP_PROBS_B = auto()
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
Expand Down Expand Up @@ -728,6 +735,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.SMALLTHINKER: "smallthinker",
MODEL_ARCH.LLADA: "llada",
MODEL_ARCH.SEED_OSS: "seed_oss",
MODEL_ARCH.GROVEMOE: "grovemoe",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
Expand Down Expand Up @@ -774,6 +782,9 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp",
MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp",
MODEL_TENSOR.FFN_GATE_CHEXP: "blk.{bid}.ffn_gate_chexps",
MODEL_TENSOR.FFN_DOWN_CHEXP: "blk.{bid}.ffn_down_chexps",
MODEL_TENSOR.FFN_UP_CHEXP: "blk.{bid}.ffn_up_chexps",
MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
Expand Down Expand Up @@ -2684,6 +2695,26 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.GROVEMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_CHEXP,
MODEL_TENSOR.FFN_DOWN_CHEXP,
MODEL_TENSOR.FFN_UP_CHEXP,
],
# TODO
}

Expand Down
9 changes: 9 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,9 @@ def add_expert_feed_forward_length(self, length: int) -> None:
def add_expert_shared_feed_forward_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

def add_expert_chunk_feed_forward_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_CHUNK_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

def add_parallel_residual(self, use: bool) -> None:
self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)

Expand Down Expand Up @@ -751,6 +754,12 @@ def add_expert_weights_norm(self, value: bool) -> None:
def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)

def add_expert_group_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value)

def add_experts_per_group(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERTS_PER_GROUP.format(arch=self.arch), count)

def add_moe_every_n_layers(self, value: int) -> None:
self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)

Expand Down
12 changes: 12 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,10 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
),

MODEL_TENSOR.FFN_UP_CHEXP: (
"model.layers.{bid}.mlp.chunk_experts.up_proj", # grovemoe
),

# AWQ-activation gate
MODEL_TENSOR.FFN_ACT: (
"transformer.blocks.{bid}.ffn.act", # mpt
Expand Down Expand Up @@ -464,6 +468,10 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
),

MODEL_TENSOR.FFN_GATE_CHEXP: (
"model.layers.{bid}.mlp.chunk_experts.gate_proj", # grovemoe
),

# Feed-forward down
MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
Expand Down Expand Up @@ -520,6 +528,10 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
),

MODEL_TENSOR.FFN_DOWN_CHEXP: (
"model.layers.{bid}.mlp.chunk_experts.down_proj", # grovemoe
),

MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
Expand Down
30 changes: 30 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
{ LLM_ARCH_LLADA, "llada" },
{ LLM_ARCH_SEED_OSS, "seed_oss" },
{ LLM_ARCH_GROVEMOE, "grovemoe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

Expand Down Expand Up @@ -124,6 +125,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
{ LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" },
{ LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" },
{ LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, "%s.expert_chunk_feed_forward_length" },
{ LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
{ LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
Expand All @@ -132,6 +134,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
{ LLM_KV_EXPERT_GROUP_SCALE, "%s.expert_group_scale" },
{ LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" },
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
{ LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
Expand Down Expand Up @@ -2152,6 +2156,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_GROVEMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_CHEXPS, "blk.%d.ffn_gate_chexps" },
{ LLM_TENSOR_FFN_DOWN_CHEXPS, "blk.%d.ffn_down_chexps" },
{ LLM_TENSOR_FFN_UP_CHEXPS, "blk.%d.ffn_up_chexps" },
},
},
{
LLM_ARCH_UNKNOWN,
{
Expand Down Expand Up @@ -2284,6 +2311,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_DOWN_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
// altup / laurel (gemma 3n)
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
Expand Down
7 changes: 7 additions & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ enum llm_arch {
LLM_ARCH_SMALLTHINKER,
LLM_ARCH_LLADA,
LLM_ARCH_SEED_OSS,
LLM_ARCH_GROVEMOE,
LLM_ARCH_UNKNOWN,
};

Expand Down Expand Up @@ -128,6 +129,7 @@ enum llm_kv {
LLM_KV_FEED_FORWARD_LENGTH,
LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH,
LLM_KV_USE_PARALLEL_RESIDUAL,
LLM_KV_TENSOR_DATA_LAYOUT,
LLM_KV_EXPERT_COUNT,
Expand All @@ -136,6 +138,8 @@ enum llm_kv {
LLM_KV_EXPERT_WEIGHTS_SCALE,
LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC,
LLM_KV_EXPERT_GROUP_SCALE,
LLM_KV_EXPERTS_PER_GROUP,
LLM_KV_MOE_EVERY_N_LAYERS,
LLM_KV_NEXTN_PREDICT_LAYERS,
LLM_KV_POOLING_TYPE,
Expand Down Expand Up @@ -292,6 +296,9 @@ enum llm_tensor {
LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_UP_SHEXP,
LLM_TENSOR_FFN_DOWN_CHEXPS,
LLM_TENSOR_FFN_GATE_CHEXPS,
LLM_TENSOR_FFN_UP_CHEXPS,
LLM_TENSOR_FFN_EXP_PROBS_B,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM,
Expand Down
Loading
Loading