
Commit 1b61ac9

refactor(py): Simplify granitemoehybrid conversion to use parents better

Branch: GraniteFour
Signed-off-by: Gabe Goodhart <[email protected]>

1 parent 1c15138 commit 1b61ac9
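
The refactor rests on a standard Python idiom rather than anything converter-specific: in a multiple-inheritance hierarchy, a subclass can call a chosen base class's method directly instead of re-implementing its body. A minimal sketch of that pattern, using hypothetical class names rather than the real converter classes:

class SsmConverter:
    def set_gguf_parameters(self):
        print("write SSM/mamba parameters")


class MoeConverter:
    def set_gguf_parameters(self):
        print("write MoE parameters")


class HybridConverter(SsmConverter, MoeConverter):
    def set_gguf_parameters(self):
        # Dispatch to each parent explicitly so both sets of parameters are
        # written without duplicating either implementation.
        MoeConverter.set_gguf_parameters(self)
        SsmConverter.set_gguf_parameters(self)


HybridConverter().set_gguf_parameters()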

File tree

1 file changed: +7 −46 lines changed


convert_hf_to_gguf.py

Lines changed: 7 additions & 46 deletions
@@ -6260,10 +6260,6 @@ class GraniteMoeHybridModel(BambaModel, GraniteMoeModel):
     SSM layers"""
     model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._transformer_model_class = GraniteMoeModel
-
     def get_attn_layres(self):
         if layer_types := self.hparams.get("layer_types"):
             return [
@@ -6275,51 +6271,16 @@ def get_attn_layres(self):
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
     ) -> Iterable[tuple[str, Tensor]]:
-
-        # In GraniteMoeHybrid, the mamba layers also have an MoE + Shared expert
-        if name.endswith("block_sparse_moe.input_linear.weight"):
-            ffn_dim = self.hparams["intermediate_size"]
-            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
-            gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
-            ]
-        if name.endswith("shared_mlp.input_linear.weight"):
-            ffn_dim = self.hparams["shared_intermediate_size"]
-            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
-            gate, up = data_torch.split(ffn_dim, dim=-2)
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
-            ]
-
+        if (
+            name.endswith("block_sparse_moe.input_linear.weight") or
+            name.endswith("shared_mlp.input_linear.weight")
+        ):
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
         return super().modify_tensors(data_torch, name, bid)
 
-
     def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        if attention_scale := self.hparams.get("attention_multiplier"):
-            self.gguf_writer.add_attention_scale(attention_scale)
-            logger.info("gguf: (granite) attention_scale = %s", attention_scale)
-        if embedding_scale := self.hparams.get("embedding_multiplier"):
-            self.gguf_writer.add_embedding_scale(embedding_scale)
-            logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
-        if residual_scale := self.hparams.get("residual_multiplier"):
-            self.gguf_writer.add_residual_scale(residual_scale)
-            logger.info("gguf: (granite) residual_scale = %s", residual_scale)
-        if logits_scale := self.hparams.get("logits_scaling"):
-            self.gguf_writer.add_logit_scale(logits_scale)
-            logger.info("gguf: (granite) logits_scale = %s", logits_scale)
-        if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
-            self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
-            logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
-        if (n_experts := self.hparams.get("num_local_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
-            logger.info(f"gguf: expert count = {n_experts}")
-        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
-            self.gguf_writer.add_expert_used_count(n_experts_used)
-            logger.info(f"gguf: experts used count = {n_experts_used}")
+        GraniteMoeModel.set_gguf_parameters(self)
+        BambaModel.set_gguf_parameters(self)
 
     def set_vocab(self):
         self.hparams["pad_vocab_size_multiple"] = 8
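
For reference, the logic removed above (and now reached through GraniteMoeModel.modify_tensors) splits a merged input_linear weight into its gate and up halves along the second-to-last dimension. A self-contained sketch of that split, with illustrative shapes only:

import torch

n_expert, ffn_dim, n_embd = 4, 8, 16
# Merged MoE FFN weight: gate rows stacked on top of up rows for each expert.
merged = torch.randn(n_expert, 2 * ffn_dim, n_embd)

assert merged.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
# Equivalent to slicing [..., :ffn_dim, :] and [..., ffn_dim:, :]
gate, up = merged.split(ffn_dim, dim=-2)
print(gate.shape, up.shape)  # torch.Size([4, 8, 16]) torch.Size([4, 8, 16])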
