
Commit 950e770

refactor(py): Simplify granitemoehybrid conversion to use parents better

Branch: GraniteFour
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent: b9bc5f8


convert_hf_to_gguf.py

Lines changed: 7 additions & 46 deletions
@@ -6054,10 +6054,6 @@ class GraniteMoeHybridModel(BambaModel, GraniteMoeModel):
     SSM layers"""
     model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._transformer_model_class = GraniteMoeModel
-
     def get_attn_layres(self):
         if layer_types := self.hparams.get("layer_types"):
             return [
@@ -6069,51 +6065,16 @@ def get_attn_layres(self):
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
     ) -> Iterable[tuple[str, Tensor]]:
-
-        # In GraniteMoeHybrid, the mamba layers also have an MoE + Shared expert
-        if name.endswith("block_sparse_moe.input_linear.weight"):
-            ffn_dim = self.hparams["intermediate_size"]
-            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
-            gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
-            ]
-        if name.endswith("shared_mlp.input_linear.weight"):
-            ffn_dim = self.hparams["shared_intermediate_size"]
-            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
-            gate, up = data_torch.split(ffn_dim, dim=-2)
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
-            ]
-
+        if (
+            name.endswith("block_sparse_moe.input_linear.weight") or
+            name.endswith("shared_mlp.input_linear.weight")
+        ):
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
         return super().modify_tensors(data_torch, name, bid)
 
-
     def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        if attention_scale := self.hparams.get("attention_multiplier"):
-            self.gguf_writer.add_attention_scale(attention_scale)
-            logger.info("gguf: (granite) attention_scale = %s", attention_scale)
-        if embedding_scale := self.hparams.get("embedding_multiplier"):
-            self.gguf_writer.add_embedding_scale(embedding_scale)
-            logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
-        if residual_scale := self.hparams.get("residual_multiplier"):
-            self.gguf_writer.add_residual_scale(residual_scale)
-            logger.info("gguf: (granite) residual_scale = %s", residual_scale)
-        if logits_scale := self.hparams.get("logits_scaling"):
-            self.gguf_writer.add_logit_scale(logits_scale)
-            logger.info("gguf: (granite) logits_scale = %s", logits_scale)
-        if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
-            self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
-            logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
-        if (n_experts := self.hparams.get("num_local_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
-            logger.info(f"gguf: expert count = {n_experts}")
-        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
-            self.gguf_writer.add_expert_used_count(n_experts_used)
-            logger.info(f"gguf: experts used count = {n_experts_used}")
+        GraniteMoeModel.set_gguf_parameters(self)
+        BambaModel.set_gguf_parameters(self)
 
     def set_vocab(self):
         self.hparams["pad_vocab_size_multiple"] = 8
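
Note: the pattern this commit moves to, a child class with two bases calling each base's implementation explicitly instead of duplicating its body, can be illustrated with a minimal, self-contained sketch. The class, method, and parameter names below are hypothetical stand-ins for illustration only, not the real converter classes or GGUF keys in convert_hf_to_gguf.py.

# Minimal sketch of explicit base-class delegation (hypothetical names).
class GraniteLike:
    """Stand-in for the parent that writes granite/MoE-style hparams."""
    def set_params(self, params: dict) -> None:
        params["attention_scale"] = 0.5

class BambaLike:
    """Stand-in for the parent that writes SSM/mamba-style hparams."""
    def set_params(self, params: dict) -> None:
        params["ssm_state_size"] = 128

class HybridLike(BambaLike, GraniteLike):
    """Stand-in for the hybrid child; delegates to each base explicitly."""
    def set_params(self, params: dict) -> None:
        # Explicit base-class calls: both parents contribute, and the child
        # controls the order, instead of re-implementing either body here.
        GraniteLike.set_params(self, params)
        BambaLike.set_params(self, params)

if __name__ == "__main__":
    params: dict = {}
    HybridLike().set_params(params)
    print(params)  # -> {'attention_scale': 0.5, 'ssm_state_size': 128}

In this sketch each explicit call runs only that base's own method, so the child decides exactly which parents contribute and in what order, which mirrors the intent of the simplified set_gguf_parameters and modify_tensors overrides above.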
