Skip to content

Commit b989962

Browse files
committed
override NeoBERT feed-forward length
1 parent 44e9400 commit b989962

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

convert_hf_to_gguf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,8 +528,6 @@ def set_gguf_parameters(self):
             logger.info(f"gguf: embedding length = {n_embd}")
 
         if (n_ff := self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"], optional=True)) is not None:
-            if self.model_arch == gguf.MODEL_ARCH.NEO_BERT:
-                n_ff = int(2 * n_ff / 3) # NeoBERT uses 2/3 of the intermediate size as feed forward length
             self.gguf_writer.add_feed_forward_length(n_ff)
             logger.info(f"gguf: feed forward length = {n_ff}")
 
@@ -4085,6 +4083,8 @@ class NeoBert(BertModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
+        # NeoBERT uses 2/3 of the intermediate size as feed forward length
+        self.gguf_writer.add_feed_forward_length(int(2 * self.hparams["intermediate_size"] / 3))
         self.gguf_writer.add_rope_freq_base(10000.0) # default value for NeoBERT
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
 
0 commit comments

Comments
 (0)