Commit c207514

refactor(py): Simplify granitemoehybrid conversion to use parents better
Branch: GraniteFour
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 346f546

File tree

1 file changed: +7 -46 lines changed

convert_hf_to_gguf.py

Lines changed: 7 additions & 46 deletions
@@ -6055,10 +6055,6 @@ class GraniteMoeHybridModel(BambaModel, GraniteMoeModel):
     SSM layers"""
     model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._transformer_model_class = GraniteMoeModel
-
     def get_attn_layres(self):
         if layer_types := self.hparams.get("layer_types"):
             return [
@@ -6070,51 +6066,16 @@ def get_attn_layres(self):
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
     ) -> Iterable[tuple[str, Tensor]]:
-
-        # In GraniteMoeHybrid, the mamba layers also have an MoE + Shared expert
-        if name.endswith("block_sparse_moe.input_linear.weight"):
-            ffn_dim = self.hparams["intermediate_size"]
-            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
-            gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
-            ]
-        if name.endswith("shared_mlp.input_linear.weight"):
-            ffn_dim = self.hparams["shared_intermediate_size"]
-            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
-            gate, up = data_torch.split(ffn_dim, dim=-2)
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
-            ]
-
+        if (
+            name.endswith("block_sparse_moe.input_linear.weight") or
+            name.endswith("shared_mlp.input_linear.weight")
+        ):
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
         return super().modify_tensors(data_torch, name, bid)
 
-
     def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        if attention_scale := self.hparams.get("attention_multiplier"):
-            self.gguf_writer.add_attention_scale(attention_scale)
-            logger.info("gguf: (granite) attention_scale = %s", attention_scale)
-        if embedding_scale := self.hparams.get("embedding_multiplier"):
-            self.gguf_writer.add_embedding_scale(embedding_scale)
-            logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
-        if residual_scale := self.hparams.get("residual_multiplier"):
-            self.gguf_writer.add_residual_scale(residual_scale)
-            logger.info("gguf: (granite) residual_scale = %s", residual_scale)
-        if logits_scale := self.hparams.get("logits_scaling"):
-            self.gguf_writer.add_logit_scale(logits_scale)
-            logger.info("gguf: (granite) logits_scale = %s", logits_scale)
-        if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
-            self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
-            logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
-        if (n_experts := self.hparams.get("num_local_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
-            logger.info(f"gguf: expert count = {n_experts}")
-        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
-            self.gguf_writer.add_expert_used_count(n_experts_used)
-            logger.info(f"gguf: experts used count = {n_experts_used}")
+        GraniteMoeModel.set_gguf_parameters(self)
+        BambaModel.set_gguf_parameters(self)
 
     def set_vocab(self):
         self.hparams["pad_vocab_size_multiple"] = 8
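For context on the branches that moved to the parent: the removed code split each merged input_linear weight into gate and up halves along dim -2, which GraniteMoeModel.modify_tensors now handles for both the routed experts and the shared expert. A minimal sketch of that split with toy sizes (the shapes below are illustrative only, not the model's real dimensions):

import torch

# Toy reproduction of the gate/up split that the deleted branches performed and
# that is now delegated to GraniteMoeModel.modify_tensors. The merged
# input_linear weight stacks the gate and up projections along dim -2.
n_expert, ffn_dim, hidden = 8, 4, 3          # illustrative sizes only
merged = torch.randn(n_expert, 2 * ffn_dim, hidden)
assert merged.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * ffn_dim"
gate, up = merged.split(ffn_dim, dim=-2)     # same as data_torch.split(ffn_dim, dim=-2)
print(gate.shape, up.shape)                  # torch.Size([8, 4, 3]) torch.Size([8, 4, 3])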

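The new set_gguf_parameters calls each parent explicitly instead of relying on super(). Because GraniteMoeHybridModel inherits from both BambaModel and GraniteMoeModel, this guarantees that both the Granite/MoE parameters and the SSM parameters are written, without depending on every class in the MRO chaining super() cooperatively. A minimal sketch of the pattern with hypothetical classes (not the converter's real API, which writes through self.gguf_writer):

class GGUFWriterStub:
    """Stand-in that just collects whatever parameters each parent writes."""
    def __init__(self) -> None:
        self.params: dict[str, object] = {}


class MoeParent:
    def set_gguf_parameters(self, writer: GGUFWriterStub) -> None:
        # MoE-side parameters (expert counts, scales, ...)
        writer.params["expert_count"] = 8


class SsmParent:
    def set_gguf_parameters(self, writer: GGUFWriterStub) -> None:
        # SSM/Mamba-side parameters (state size, conv kernel, ...)
        writer.params["ssm_state_size"] = 128


class HybridChild(SsmParent, MoeParent):
    def set_gguf_parameters(self, writer: GGUFWriterStub) -> None:
        # Mirror of the refactor: call each parent explicitly so both sets of
        # parameters are always written, regardless of MRO order.
        MoeParent.set_gguf_parameters(self, writer)
        SsmParent.set_gguf_parameters(self, writer)


writer = GGUFWriterStub()
HybridChild().set_gguf_parameters(writer)
print(writer.params)  # {'expert_count': 8, 'ssm_state_size': 128}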