Commit 991b8e6

CISC authored and pwilkin committed
model : add GroveMoE support (ggml-org#15510)
* add GroveMoE support
* remove constexpr that fails on certain compilers
* revert crude scalar div implementation, use cast
* build_attn_inp_kv_unified -> build_attn_inp_kv
* fix build_attn
* re-apply ffn_exps regex changes
1 parent 7bf6bdc commit 991b8e6

File tree

11 files changed: +13210 -178 lines changed


common/common.h

Lines changed: 1 addition & 1 deletion
@@ -740,7 +740,7 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 // MoE utils
 //
 
-const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_exps";
+const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
 
 static std::string llm_ffn_exps_block_regex(int idx) {
     return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
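
Note: the broadened pattern now matches the new chunk-expert tensors alongside the regular expert tensors. A quick sanity check of the same pattern, sketched in Python for illustration only (the C++ side applies it through llm_ffn_exps_block_regex):

import re

# Same pattern as LLM_FFN_EXPS_REGEX above, combined per block the way
# llm_ffn_exps_block_regex() builds it.
FFN_EXPS_REGEX = r"\.ffn_(up|down|gate)_(ch|)exps"

def block_regex(idx: int) -> str:
    return rf"blk\.{idx}" + FFN_EXPS_REGEX

assert re.search(block_regex(0), "blk.0.ffn_up_exps")      # regular experts
assert re.search(block_regex(0), "blk.0.ffn_gate_chexps")  # chunk experts
assert not re.search(block_regex(0), "blk.0.ffn_norm")     # unrelated tensor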

convert_hf_to_gguf.py

Lines changed: 115 additions & 0 deletions
@@ -8025,6 +8025,121 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
+class GroveMoeModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.GROVEMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n_experts := self.hparams.get("num_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+            logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
+        # FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L299
+        self.gguf_writer.add_expert_chunk_feed_forward_length(self.hparams.get("head_dim") or 128)
+        # FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L298
+        self.gguf_writer.add_experts_per_group(2)
+        # FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L376
+        self.gguf_writer.add_expert_group_scale(0.05)
+        # YaRN is not enabled by default
+        # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+    _experts: list[dict[str, Tensor]] | None = None
+    _chunk_experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith(".expert_bias"):
+            # FIXME?: Unused https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L303
+            return []
+
+        # process the experts separately
+        if name.find("chunk_experts") != -1:
+            n_experts = self.hparams["num_experts"] // 2  # see add_experts_per_group
+            assert bid is not None
+
+            if self._chunk_experts is None:
+                self._chunk_experts = [{} for _ in range(self.block_count)]
+
+            self._chunk_experts[bid][name] = data_torch
+
+            if len(self._chunk_experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.chunk_experts.{xid}.{w_name}.weight"
+                        datas.append(self._chunk_experts[bid][ename])
+                        del self._chunk_experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.chunk_experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+        elif name.find("experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._chunk_experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            chunk_experts = [k for d in self._chunk_experts for k in d.keys()]
+            if len(chunk_experts) > 0:
+                raise ValueError(f"Unprocessed adjugate experts: {chunk_experts}")
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("ChameleonForConditionalGeneration")
 @ModelBase.register("ChameleonForCausalLM")  # obsolete
 class ChameleonModel(TextModel):
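
Note: the converter buffers each block's per-expert 2D weights and, once all have arrived, emits one stacked 3D tensor per projection; the chunk experts are counted as num_experts // 2, matching the hardcoded add_experts_per_group(2). A minimal sketch of that merge step, with made-up shapes (not GroveMoE's real dimensions):

import torch

n_experts, n_ff, n_embd = 4, 32, 16  # illustrative sizes only

# one 2D weight per expert, as found in the HF checkpoint ...
per_expert = [torch.randn(n_ff, n_embd) for _ in range(n_experts)]

# ... merged exactly as modify_tensors() does with torch.stack(datas, dim=0)
merged = torch.stack(per_expert, dim=0)
assert merged.shape == (n_experts, n_ff, n_embd)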

gguf-py/gguf/constants.py

Lines changed: 31 additions & 0 deletions
@@ -96,6 +96,7 @@ class LLM:
         FEED_FORWARD_LENGTH               = "{arch}.feed_forward_length"
         EXPERT_FEED_FORWARD_LENGTH        = "{arch}.expert_feed_forward_length"
         EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length"
+        EXPERT_CHUNK_FEED_FORWARD_LENGTH  = "{arch}.expert_chunk_feed_forward_length"
         USE_PARALLEL_RESIDUAL             = "{arch}.use_parallel_residual"
         TENSOR_DATA_LAYOUT                = "{arch}.tensor_data_layout"
         EXPERT_COUNT                      = "{arch}.expert_count"
@@ -104,6 +105,8 @@ class LLM:
         EXPERT_WEIGHTS_SCALE              = "{arch}.expert_weights_scale"
         EXPERT_WEIGHTS_NORM               = "{arch}.expert_weights_norm"
         EXPERT_GATING_FUNC                = "{arch}.expert_gating_func"
+        EXPERT_GROUP_SCALE                = "{arch}.expert_group_scale"
+        EXPERTS_PER_GROUP                 = "{arch}.experts_per_group"
         MOE_EVERY_N_LAYERS                = "{arch}.moe_every_n_layers"
         NEXTN_PREDICT_LAYERS              = "{arch}.nextn_predict_layers"
         POOLING_TYPE                      = "{arch}.pooling_type"
@@ -402,6 +405,7 @@ class MODEL_ARCH(IntEnum):
     LLADA            = auto()
     LLADA_MOE        = auto()
     SEED_OSS         = auto()
+    GROVEMOE         = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -452,6 +456,9 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE_SHEXP       = auto()
     FFN_DOWN_SHEXP       = auto()
     FFN_UP_SHEXP         = auto()
+    FFN_GATE_CHEXP       = auto()
+    FFN_DOWN_CHEXP       = auto()
+    FFN_UP_CHEXP         = auto()
     FFN_EXP_PROBS_B      = auto()
     ATTN_Q_NORM          = auto()
     ATTN_K_NORM          = auto()
@@ -742,6 +749,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.LLADA:            "llada",
     MODEL_ARCH.LLADA_MOE:        "llada-moe",
     MODEL_ARCH.SEED_OSS:         "seed_oss",
+    MODEL_ARCH.GROVEMOE:         "grovemoe",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -789,6 +797,9 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_GATE_SHEXP:          "blk.{bid}.ffn_gate_shexp",
     MODEL_TENSOR.FFN_DOWN_SHEXP:          "blk.{bid}.ffn_down_shexp",
     MODEL_TENSOR.FFN_UP_SHEXP:            "blk.{bid}.ffn_up_shexp",
+    MODEL_TENSOR.FFN_GATE_CHEXP:          "blk.{bid}.ffn_gate_chexps",
+    MODEL_TENSOR.FFN_DOWN_CHEXP:          "blk.{bid}.ffn_down_chexps",
+    MODEL_TENSOR.FFN_UP_CHEXP:            "blk.{bid}.ffn_up_chexps",
     MODEL_TENSOR.FFN_ACT:                 "blk.{bid}.ffn",
     MODEL_TENSOR.FFN_NORM_EXP:            "blk.{bid}.ffn_norm_exps",
     MODEL_TENSOR.FFN_GATE_EXP:            "blk.{bid}.ffn_gate_exps",
@@ -2747,6 +2758,26 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
     ],
+    MODEL_ARCH.GROVEMOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_CHEXP,
+        MODEL_TENSOR.FFN_DOWN_CHEXP,
+        MODEL_TENSOR.FFN_UP_CHEXP,
+    ],
     # TODO
 }
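
Note: the new KV templates are formatted with the architecture name at write time; for GroveMoE ("grovemoe" per MODEL_ARCH_NAMES above) they render as shown in this small sketch:

for key in ("{arch}.expert_chunk_feed_forward_length",
            "{arch}.expert_group_scale",
            "{arch}.experts_per_group"):
    print(key.format(arch="grovemoe"))
# grovemoe.expert_chunk_feed_forward_length
# grovemoe.expert_group_scale
# grovemoe.experts_per_group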

gguf-py/gguf/gguf_writer.py

Lines changed: 9 additions & 0 deletions
@@ -670,6 +670,9 @@ def add_expert_feed_forward_length(self, length: int) -> None:
     def add_expert_shared_feed_forward_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
 
+    def add_expert_chunk_feed_forward_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_CHUNK_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+
     def add_parallel_residual(self, use: bool) -> None:
         self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
 
@@ -757,6 +760,12 @@ def add_expert_weights_norm(self, value: bool) -> None:
     def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
         self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
 
+    def add_expert_group_scale(self, value: float) -> None:
+        self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value)
+
+    def add_experts_per_group(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERTS_PER_GROUP.format(arch=self.arch), count)
+
     def add_moe_every_n_layers(self, value: int) -> None:
         self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
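
Note: these methods give converters a typed API for the new metadata. A hedged usage sketch (standalone construction and the file name are assumptions; in the commit, GroveMoeModel calls these on its own self.gguf_writer with the hardcoded values shown earlier, and normal GGUF finalization still applies):

from gguf import GGUFWriter

writer = GGUFWriter("grovemoe.gguf", arch="grovemoe")
writer.add_expert_chunk_feed_forward_length(128)  # head_dim fallback in the converter
writer.add_experts_per_group(2)
writer.add_expert_group_scale(0.05)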

gguf-py/gguf/tensor_mapping.py

Lines changed: 12 additions & 0 deletions
@@ -427,6 +427,10 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
         ),
 
+        MODEL_TENSOR.FFN_UP_CHEXP: (
+            "model.layers.{bid}.mlp.chunk_experts.up_proj", # grovemoe
+        ),
+
         # AWQ-activation gate
         MODEL_TENSOR.FFN_ACT: (
             "transformer.blocks.{bid}.ffn.act", # mpt
@@ -468,6 +472,10 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
         ),
 
+        MODEL_TENSOR.FFN_GATE_CHEXP: (
+            "model.layers.{bid}.mlp.chunk_experts.gate_proj", # grovemoe
+        ),
+
         # Feed-forward down
         MODEL_TENSOR.FFN_DOWN: (
             "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
@@ -524,6 +532,10 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
         ),
 
+        MODEL_TENSOR.FFN_DOWN_CHEXP: (
+            "model.layers.{bid}.mlp.chunk_experts.down_proj", # grovemoe
+        ),
+
         MODEL_TENSOR.ATTN_Q_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
             "model.layers.{bid}.self_attn.q_layernorm", # persimmon

src/llama-arch.cpp

Lines changed: 30 additions & 0 deletions
@@ -99,6 +99,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLADA,            "llada"        },
     { LLM_ARCH_LLADA_MOE,        "llada-moe"    },
     { LLM_ARCH_SEED_OSS,         "seed_oss"     },
+    { LLM_ARCH_GROVEMOE,         "grovemoe"     },
     { LLM_ARCH_UNKNOWN,          "(unknown)"    },
 };
 
@@ -126,6 +127,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_FEED_FORWARD_LENGTH,                "%s.feed_forward_length"                },
     { LLM_KV_EXPERT_FEED_FORWARD_LENGTH,         "%s.expert_feed_forward_length"         },
     { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,  "%s.expert_shared_feed_forward_length"  },
+    { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH,   "%s.expert_chunk_feed_forward_length"   },
     { LLM_KV_USE_PARALLEL_RESIDUAL,              "%s.use_parallel_residual"              },
     { LLM_KV_TENSOR_DATA_LAYOUT,                 "%s.tensor_data_layout"                 },
     { LLM_KV_EXPERT_COUNT,                       "%s.expert_count"                       },
@@ -134,6 +136,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_SCALE,               "%s.expert_weights_scale"               },
     { LLM_KV_EXPERT_WEIGHTS_NORM,                "%s.expert_weights_norm"                },
     { LLM_KV_EXPERT_GATING_FUNC,                 "%s.expert_gating_func"                 },
+    { LLM_KV_EXPERT_GROUP_SCALE,                 "%s.expert_group_scale"                 },
+    { LLM_KV_EXPERTS_PER_GROUP,                  "%s.experts_per_group"                  },
     { LLM_KV_MOE_EVERY_N_LAYERS,                 "%s.moe_every_n_layers"                 },
     { LLM_KV_NEXTN_PREDICT_LAYERS,               "%s.nextn_predict_layers"               },
     { LLM_KV_POOLING_TYPE,                       "%s.pooling_type"                       },
@@ -2219,6 +2223,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_GROVEMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_CHEXPS, "blk.%d.ffn_gate_chexps" },
+            { LLM_TENSOR_FFN_DOWN_CHEXPS, "blk.%d.ffn_down_chexps" },
+            { LLM_TENSOR_FFN_UP_CHEXPS,   "blk.%d.ffn_up_chexps" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2352,6 +2379,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_DOWN_EXPS,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_GATE_EXPS,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_DOWN_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_UP_CHEXPS,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     // altup / laurel (gemma 3n)
     {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},

src/llama-arch.h

Lines changed: 7 additions & 0 deletions
@@ -103,6 +103,7 @@ enum llm_arch {
     LLM_ARCH_LLADA,
     LLM_ARCH_LLADA_MOE,
     LLM_ARCH_SEED_OSS,
+    LLM_ARCH_GROVEMOE,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -130,6 +131,7 @@ enum llm_kv {
     LLM_KV_FEED_FORWARD_LENGTH,
     LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
     LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
+    LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH,
     LLM_KV_USE_PARALLEL_RESIDUAL,
     LLM_KV_TENSOR_DATA_LAYOUT,
     LLM_KV_EXPERT_COUNT,
@@ -138,6 +140,8 @@ enum llm_kv {
     LLM_KV_EXPERT_WEIGHTS_SCALE,
     LLM_KV_EXPERT_WEIGHTS_NORM,
     LLM_KV_EXPERT_GATING_FUNC,
+    LLM_KV_EXPERT_GROUP_SCALE,
+    LLM_KV_EXPERTS_PER_GROUP,
     LLM_KV_MOE_EVERY_N_LAYERS,
     LLM_KV_NEXTN_PREDICT_LAYERS,
     LLM_KV_POOLING_TYPE,
@@ -302,6 +306,9 @@ enum llm_tensor {
     LLM_TENSOR_FFN_DOWN_SHEXP,
     LLM_TENSOR_FFN_GATE_SHEXP,
     LLM_TENSOR_FFN_UP_SHEXP,
+    LLM_TENSOR_FFN_DOWN_CHEXPS,
+    LLM_TENSOR_FFN_GATE_CHEXPS,
+    LLM_TENSOR_FFN_UP_CHEXPS,
     LLM_TENSOR_FFN_EXP_PROBS_B,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,

src/llama-graph.cpp

Lines changed: 15 additions & 2 deletions
@@ -923,13 +923,26 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         selection_probs = logits;
     }
 
+    if (arch == LLM_ARCH_GROVEMOE) {
+        selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
     // select experts
     ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
     cb(selected_experts, "ffn_moe_topk", il);
 
-    ggml_tensor * weights = ggml_get_rows(ctx0,
-            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+    if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
+        // TODO: Use scalar div instead when/if implemented
+        ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
+        selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
+        probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
+    } else {
+        probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
+    }
+
+    ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
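
Note on the commit line "revert crude scalar div implementation, use cast": for the chunk-expert pass, the top-k indices chosen over the full expert set must be divided by the experts-per-group count to address the shared chunk experts. ggml has no scalar integer divide, so the graph casts to F32, scales, and casts back; the F32-to-I32 cast truncates, which equals floor division for non-negative indices. The same arithmetic sketched in NumPy:

import numpy as np

n_group_experts = 2  # experts per group (the GGUF key written by the converter)

selected_experts = np.array([5, 2, 7, 0], dtype=np.int32)  # top-k expert ids

f_sel = selected_experts.astype(np.float32)                     # ggml_cast to F32
group_idx = (f_sel * (1.0 / n_group_experts)).astype(np.int32)  # scale, cast back

# truncation equals floor division for non-negative indices
assert (group_idx == selected_experts // n_group_experts).all()
print(group_idx)  # [2 1 3 0]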
