Commit 2aed91c

fix: Consolidate GraniteMoEShared into GraniteMoE for conversion

Branch: GraniteMoEShared
Signed-off-by: Gabe Goodhart <[email protected]>

1 parent: 3d79214

File tree

3 files changed: +148 -180 lines

convert_hf_to_gguf.py

Lines changed: 11 additions & 26 deletions
@@ -5669,11 +5669,20 @@ def set_gguf_parameters(self):
         logger.info("gguf: (granite) logits_scale = %s", logits_scale)
 
 
-@ModelBase.register("GraniteMoeForCausalLM")
+@ModelBase.register("GraniteMoeForCausalLM", "GraniteMoeSharedForCausalLM")
 class GraniteMoeModel(GraniteModel):
     """Conversion for IBM's GraniteMoeForCausalLM"""
     model_arch = gguf.MODEL_ARCH.GRANITE_MOE
 
+    def set_gguf_parameters(self):
+        """GraniteMoeShared uses GraniteMoe parameters plus the following:
+        - shared_intermediate_size
+        """
+        super().set_gguf_parameters()
+        if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
+            self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
+            logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         """In modeling_granitemoe, the JetMoe implementation of parallel experts
         is used. This essentially merges w1 and w3 into a single tensor with 2x
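
The hoisted set_gguf_parameters override is a no-op for plain GraniteMoe checkpoints: hparams.get("shared_intermediate_size") returns None, the walrus-guarded branch is skipped, and no extra KV is written. A minimal sketch of that behavior (the helper and hparams dicts below are illustrative, not from the commit; the key name follows gguf-py's expert_shared_feed_forward_length convention):

# Illustrative helper (hypothetical): which GGUF KV the new branch produces
# for a shared-expert config vs. a plain MoE config.
def shared_ffn_kv(hparams: dict, arch: str = "granitemoe") -> dict:
    kv = {}
    # mirrors: if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
    if shared_feed_forward_length := hparams.get("shared_intermediate_size"):
        # GGUFWriter.add_expert_shared_feed_forward_length targets this key
        kv[f"{arch}.expert_shared_feed_forward_length"] = shared_feed_forward_length
    return kv

print(shared_ffn_kv({"intermediate_size": 1024}))
# {} -> plain GraniteMoe: branch skipped, conversion unchanged
print(shared_ffn_kv({"intermediate_size": 1024, "shared_intermediate_size": 512}))
# {'granitemoe.expert_shared_feed_forward_length': 512} -> GraniteMoeShared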
@@ -5684,36 +5693,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if name.endswith("block_sparse_moe.input_linear.weight"):
             ffn_dim = self.hparams["intermediate_size"]
             assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
-            gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
+            gate, up = data_torch.split(ffn_dim, dim=-2)
             return [
                 (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
                 (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
             ]
 
-        return super().modify_tensors(data_torch, name, bid)
-
-
-@ModelBase.register("GraniteMoeSharedForCausalLM")
-class GraniteMoeSharedModel(GraniteMoeModel):
-    """Conversion for IBM's GraniteMoeSharedForCausalLM"""
-    model_arch = gguf.MODEL_ARCH.GRANITE_MOE_SHARED
-
-    def set_gguf_parameters(self):
-        """GraniteMoeShared uses GraniteMoe parameters plus the following:
-        - shared_intermediate_size
-        """
-        super().set_gguf_parameters()
-        if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
-            self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
-            logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        """In modeling_granitemoeshared, the implementation of parallel experts
-        is used. This essentially merges w1 and w3 into a single tensor with 2x
-        the hidden size that is then split during forward. To keep compatibility
-        with existing shared expert support, we pull them apart here.
-        """
         if name.endswith("shared_mlp.input_linear.weight"):
             ffn_dim = self.hparams["shared_intermediate_size"]
             assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
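
The only functional-looking edit inside modify_tensors is a pure refactor: explicit slicing of the merged gate/up tensor is replaced by torch.Tensor.split. A standalone sketch of the equivalence (shapes are illustrative, not taken from a real checkpoint; the converter sees a tensor whose second-to-last dimension is 2 * ffn_dim):

# Equivalence check for the slicing -> split refactor (illustrative shapes).
import torch

n_expert, ffn_dim, n_embd = 4, 8, 16
data_torch = torch.randn(n_expert, 2 * ffn_dim, n_embd)

# old: explicit slices along dim -2
gate_a, up_a = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
# new: split into ffn_dim-sized chunks along dim -2
gate_b, up_b = data_torch.split(ffn_dim, dim=-2)

assert torch.equal(gate_a, gate_b) and torch.equal(up_a, up_b)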

gguf-py/gguf/constants.py

Lines changed: 136 additions & 153 deletions

(In the first two hunks the diff re-emits every entry because removing GRANITE_MOE_SHARED, the longest name in each block, shrinks the alignment column; those alignment-only pairs are collapsed into context lines below.)

@@ -255,75 +255,74 @@ class GGUFType:
 
 
 class MODEL_ARCH(IntEnum):
     CLIP_VISION = auto()  # dummy arch for clip.cpp
     LLAMA = auto()
     LLAMA4 = auto()
     DECI = auto()
     FALCON = auto()
     BAICHUAN = auto()
     GROK = auto()
     GPT2 = auto()
     GPTJ = auto()
     GPTNEOX = auto()
     MPT = auto()
     STARCODER = auto()
     REFACT = auto()
     BERT = auto()
     NOMIC_BERT = auto()
     NOMIC_BERT_MOE = auto()
     JINA_BERT_V2 = auto()
     BLOOM = auto()
     STABLELM = auto()
     QWEN = auto()
     QWEN2 = auto()
     QWEN2MOE = auto()
     QWEN2VL = auto()
     QWEN3 = auto()
     QWEN3MOE = auto()
     PHI2 = auto()
     PHI3 = auto()
     PHIMOE = auto()
     PLAMO = auto()
     CODESHELL = auto()
     ORION = auto()
     INTERNLM2 = auto()
     MINICPM = auto()
     MINICPM3 = auto()
     GEMMA = auto()
     GEMMA2 = auto()
     GEMMA3 = auto()
     STARCODER2 = auto()
     RWKV6 = auto()
     RWKV6QWEN2 = auto()
     RWKV7 = auto()
     ARWKV7 = auto()
     MAMBA = auto()
     XVERSE = auto()
     COMMAND_R = auto()
     COHERE2 = auto()
     DBRX = auto()
     OLMO = auto()
     OLMO2 = auto()
     OLMOE = auto()
     OPENELM = auto()
     ARCTIC = auto()
     DEEPSEEK = auto()
     DEEPSEEK2 = auto()
     CHATGLM = auto()
     GLM4 = auto()
     BITNET = auto()
     T5 = auto()
     T5ENCODER = auto()
     JAIS = auto()
     NEMOTRON = auto()
     EXAONE = auto()
     GRANITE = auto()
     GRANITE_MOE = auto()
-    GRANITE_MOE_SHARED = auto()
     CHAMELEON = auto()
     WAVTOKENIZER_DEC = auto()
     PLM = auto()
     BAILINGMOE = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):

@@ -513,75 +512,74 @@ class MODEL_TENSOR(IntEnum):
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.CLIP_VISION: "clip",  # dummy arch for clip.cpp
     MODEL_ARCH.LLAMA: "llama",
     MODEL_ARCH.LLAMA4: "llama4",
     MODEL_ARCH.DECI: "deci",
     MODEL_ARCH.FALCON: "falcon",
     MODEL_ARCH.BAICHUAN: "baichuan",
     MODEL_ARCH.GROK: "grok",
     MODEL_ARCH.GPT2: "gpt2",
     MODEL_ARCH.GPTJ: "gptj",
     MODEL_ARCH.GPTNEOX: "gptneox",
     MODEL_ARCH.MPT: "mpt",
     MODEL_ARCH.STARCODER: "starcoder",
     MODEL_ARCH.REFACT: "refact",
     MODEL_ARCH.BERT: "bert",
     MODEL_ARCH.NOMIC_BERT: "nomic-bert",
     MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
     MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
     MODEL_ARCH.BLOOM: "bloom",
     MODEL_ARCH.STABLELM: "stablelm",
     MODEL_ARCH.QWEN: "qwen",
     MODEL_ARCH.QWEN2: "qwen2",
     MODEL_ARCH.QWEN2MOE: "qwen2moe",
     MODEL_ARCH.QWEN2VL: "qwen2vl",
     MODEL_ARCH.QWEN3: "qwen3",
     MODEL_ARCH.QWEN3MOE: "qwen3moe",
     MODEL_ARCH.PHI2: "phi2",
     MODEL_ARCH.PHI3: "phi3",
     MODEL_ARCH.PHIMOE: "phimoe",
     MODEL_ARCH.PLAMO: "plamo",
     MODEL_ARCH.CODESHELL: "codeshell",
     MODEL_ARCH.ORION: "orion",
     MODEL_ARCH.INTERNLM2: "internlm2",
     MODEL_ARCH.MINICPM: "minicpm",
     MODEL_ARCH.MINICPM3: "minicpm3",
     MODEL_ARCH.GEMMA: "gemma",
     MODEL_ARCH.GEMMA2: "gemma2",
     MODEL_ARCH.GEMMA3: "gemma3",
     MODEL_ARCH.STARCODER2: "starcoder2",
     MODEL_ARCH.RWKV6: "rwkv6",
     MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
     MODEL_ARCH.RWKV7: "rwkv7",
     MODEL_ARCH.ARWKV7: "arwkv7",
     MODEL_ARCH.MAMBA: "mamba",
     MODEL_ARCH.XVERSE: "xverse",
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.COHERE2: "cohere2",
     MODEL_ARCH.DBRX: "dbrx",
     MODEL_ARCH.OLMO: "olmo",
     MODEL_ARCH.OLMO2: "olmo2",
     MODEL_ARCH.OLMOE: "olmoe",
     MODEL_ARCH.OPENELM: "openelm",
     MODEL_ARCH.ARCTIC: "arctic",
     MODEL_ARCH.DEEPSEEK: "deepseek",
     MODEL_ARCH.DEEPSEEK2: "deepseek2",
     MODEL_ARCH.CHATGLM: "chatglm",
     MODEL_ARCH.GLM4: "glm4",
     MODEL_ARCH.BITNET: "bitnet",
     MODEL_ARCH.T5: "t5",
     MODEL_ARCH.T5ENCODER: "t5encoder",
     MODEL_ARCH.JAIS: "jais",
     MODEL_ARCH.NEMOTRON: "nemotron",
     MODEL_ARCH.EXAONE: "exaone",
     MODEL_ARCH.GRANITE: "granite",
     MODEL_ARCH.GRANITE_MOE: "granitemoe",
-    MODEL_ARCH.GRANITE_MOE_SHARED: "granitemoeshared",
     MODEL_ARCH.CHAMELEON: "chameleon",
     MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
     MODEL_ARCH.PLM: "plm",
     MODEL_ARCH.BAILINGMOE: "bailingmoe",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {

@@ -1895,21 +1893,6 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
-    ],
-    MODEL_ARCH.GRANITE_MOE_SHARED: [
-        MODEL_TENSOR.TOKEN_EMBD,
-        MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_Q,
-        MODEL_TENSOR.ATTN_K,
-        MODEL_TENSOR.ATTN_V,
-        MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_NORM,
-        MODEL_TENSOR.FFN_GATE_INP,
-        MODEL_TENSOR.FFN_GATE_EXP,
-        MODEL_TENSOR.FFN_DOWN_EXP,
-        MODEL_TENSOR.FFN_UP_EXP,
         MODEL_TENSOR.FFN_GATE_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
         MODEL_TENSOR.FFN_DOWN_SHEXP,
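
With the per-arch tensor lists merged, downstream code that consults gguf-py's tables now finds the shared-expert tensors under the existing granitemoe arch. A quick sanity-check sketch against gguf-py's public constants (not part of the commit):

# Sanity-check sketch: the shared-expert tensors now live under GRANITE_MOE.
from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSOR, MODEL_TENSORS

assert MODEL_ARCH_NAMES[MODEL_ARCH.GRANITE_MOE] == "granitemoe"
assert MODEL_TENSOR.FFN_DOWN_SHEXP in MODEL_TENSORS[MODEL_ARCH.GRANITE_MOE]
assert not hasattr(MODEL_ARCH, "GRANITE_MOE_SHARED")  # the enum entry is removed above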

gguf-py/gguf/tensor_mapping.py

Lines changed: 1 addition & 1 deletion

@@ -428,7 +428,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert.down_proj",  # qwen2moe
             "model.layers.{bid}.mlp.shared_experts.down_proj",  # deepseek deepseek2
             "language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj",  # llama4
-            "model.layers.{bid}.shared_mlp.output_linear",  # granitemoeshared
+            "model.layers.{bid}.shared_mlp.output_linear",  # granitemoe
         ),
 
         MODEL_TENSOR.ATTN_Q_NORM: (
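
Since the HF tensor name itself is unchanged (only its trailing comment moves to the consolidated arch), the mapping still resolves GraniteMoeShared's shared-MLP projection under the granitemoe arch. A sketch using gguf-py's name map (the printed name is an expectation based on the FFN_DOWN_SHEXP naming convention, not output from the commit):

# Sketch: resolving the HF shared-MLP tensor under the consolidated arch.
from gguf.constants import MODEL_ARCH
from gguf.tensor_mapping import get_tensor_name_map

name_map = get_tensor_name_map(MODEL_ARCH.GRANITE_MOE, 2)  # 2 blocks for the sketch
gguf_name = name_map.get_name("model.layers.0.shared_mlp.output_linear.weight",
                              try_suffixes=(".weight", ".bias"))
print(gguf_name)  # expected: blk.0.ffn_down_shexp.weight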
