
Commit e8d99dd

gabe-l-hart, DominguesM, and jwjohns authored
nvidia nemotron nano v2 (nemotronh) (ggml-org#15507)
Issue: https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

* feat: Add NEMOTRONH to python arch enum
* feat: Add NEMOTRONH to c++ arch enum
* feat: Add NEMOTRONH to llama-arch layer map
* feat: First pass at conversion for nemotronh
* feat: Add a verbose log for each tensor loaded. This is really helpful for diagnosing mismatches between the expected and received tensors.
* feat: First (broken) pass at nemotronh model architecture. It generates tokens, just not valid ones!
* fix: Explicitly enable add_bos_token during conversion. The `tokenizer.json`/`tokenizer_config.json` in the model are a bit contradictory: in the config, add_bos_token is set to False, but the tokenizer model itself has a post_processor that adds the BOS token via type TemplateProcessing.
* fix: Use relu2 (LLM_FFN_RELU_SQR) for activation in FFN layers
* fix: Only allocate attention cache for attention layers (not non-recurrent)
* fix: Move residual add to after every block
* fix: Use the correct norm tensor for the MLP blocks
* Nemotron-H: MLP gate cleanup (pass NULL for unused gate). This model does not use a gate in MLP blocks; pass NULLs for gate tensors to make intent clear and avoid unused-pointer noise.
* SSM: respect ssm_dt_rank for dt_dim when provided. Use GGUF-provided time_step_rank (ssm_dt_rank) to set dt_dim when > 0; fall back to max(64, n_embd/16).
* fix: plamo2 - revert dt_dim to default (remove ssm_dt_rank usage)
* Rename nemotronh to nemotron_h for consistency
  - Update architecture name from NEMOTRONH to NEMOTRON_H in constants.py
  - Change architecture string from 'nemotronh' to 'nemotron_h' in all files
  - Update enum LLM_ARCH_NEMOTRONH to LLM_ARCH_NEMOTRON_H
  - Update class name llm_build_nemotronh to llm_build_nemotron_h
  - Consistent naming with underscore convention (nemotron_h vs nemotronh)
* feat: Support conversion for older NemotronH models

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Maicon Domingues <[email protected]>
Co-authored-by: weatherman <[email protected]>
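For reference, the relu2 and "MLP gate cleanup" items above describe a squared-ReLU, gateless feed-forward block. A minimal NumPy sketch of what that activation computes (the w_up / w_down names are illustrative placeholders, not identifiers from this commit):

import numpy as np

def relu2(x: np.ndarray) -> np.ndarray:
    # Squared ReLU (what llama.cpp calls LLM_FFN_RELU_SQR): max(x, 0) ** 2
    r = np.maximum(x, 0.0)
    return r * r

def gateless_mlp(x: np.ndarray, w_up: np.ndarray, w_down: np.ndarray) -> np.ndarray:
    # Nemotron-H MLP blocks have no gate projection, so the block is just
    # up-projection -> relu2 -> down-projection.
    return relu2(x @ w_up) @ w_down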
1 parent a8bca68 commit e8d99dd

File tree

7 files changed: +362 -8 lines changed

convert_hf_to_gguf.py

Lines changed: 58 additions & 5 deletions
@@ -7546,9 +7546,13 @@ def __init__(self, *args, **kwargs):
         ]
 
         # n_group and d_inner are used during reshape_tensors for mamba2
-        self.d_model = self.find_hparam(["hidden_size", "d_model"])
-        self.n_group = self.find_hparam(["n_groups"])
-        self.d_inner = self.find_hparam(["expand"]) * self.d_model
+        # NOTE: Explicitly include hparam prefix for d_model to
+        # disambiguate with top-level head_dim
+        # NOTE 2: If needed for future models, this can be isolated in a method
+        # to separate the prefix setting and the keys used
+        self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups", "num_groups"])
+        self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model
 
     def get_attn_layers(self):
         # Explicit list of layer type names
@@ -7609,12 +7613,12 @@ def set_gguf_parameters(self):
 
         ## Mamba mixer params ##
         self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
-        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"]))
         self.gguf_writer.add_ssm_group_count(self.n_group)
         self.gguf_writer.add_ssm_inner_size(self.d_inner)
         # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
         # in llama.cpp
-        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"]))
 
         ## Attention params ##
         head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
@@ -7641,6 +7645,55 @@ def set_vocab(self):
         Mamba2Model.set_vocab(self)
 
 
+@ModelBase.register("NemotronHForCausalLM")
+class NemotronHModel(GraniteHybridModel):
+    """Hybrid mamba2/attention model from NVIDIA"""
+    model_arch = gguf.MODEL_ARCH.NEMOTRON_H
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Save the top-level head_dim for later
+        self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim"))
+        assert self.head_dim is not None, "Could not find the attention head dim in config"
+
+        # Don't use expand to calculate d_inner
+        self.d_inner = self.find_hparam(["num_heads"]) * self.d_model
+
+        # Update the ssm / attn / mlp layers
+        # M: Mamba2, *: Attention, -: MLP
+        hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
+        self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
+        self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
+
+    def get_attn_layers(self):
+        hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
+        assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
+        return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        self.gguf_writer.add_key_length(self.head_dim)
+        self.gguf_writer.add_value_length(self.head_dim)
+
+        # Set feed_forward_length
+        # NOTE: This will trigger an override warning. This is preferable to
+        # duplicating all the parent logic
+        n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
+        self.gguf_writer.add_feed_forward_length([
+            n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
+        ])
+
+    def set_vocab(self):
+        super().set_vocab()
+
+        # The tokenizer _does_ add a BOS token (via post_processor type
+        # TemplateProcessing) but does not set add_bos_token to true in the
+        # config, so we need to explicitly override it here.
+        self.gguf_writer.add_add_bos_token(True)
+
+
 @ModelBase.register("BailingMoeForCausalLM")
 class BailingMoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BAILINGMOE
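As a standalone illustration of how NemotronHModel above partitions layers, the sketch below parses a made-up hybrid_override_pattern; the pattern string and n_ff value are invented for the example, not taken from a real config:

# Illustrative only: a short made-up pattern; real Nemotron-H configs have one
# character per hidden layer ("M" = Mamba2, "*" = attention, "-" = MLP).
hybrid_override_pattern = "MM-M*-MM-"
n_ff = 15680  # hypothetical intermediate_size

ssm_layers  = [i for i, c in enumerate(hybrid_override_pattern) if c == "M"]
attn_layers = [i for i, c in enumerate(hybrid_override_pattern) if c == "*"]
mlp_layers  = [i for i, c in enumerate(hybrid_override_pattern) if c == "-"]

# Per-layer feed_forward_length: n_ff for MLP layers, 0 elsewhere,
# mirroring the add_feed_forward_length call in NemotronHModel above.
feed_forward_length = [n_ff if i in mlp_layers else 0
                       for i in range(len(hybrid_override_pattern))]

print(ssm_layers)           # [0, 1, 3, 6, 7]
print(attn_layers)          # [4]
print(mlp_layers)           # [2, 5, 8]
print(feed_forward_length)  # [0, 0, 15680, 0, 0, 15680, 0, 0, 15680]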

gguf-py/gguf/constants.py

Lines changed: 21 additions & 0 deletions
@@ -367,6 +367,7 @@ class MODEL_ARCH(IntEnum):
     T5ENCODER = auto()
     JAIS = auto()
     NEMOTRON = auto()
+    NEMOTRON_H = auto()
     EXAONE = auto()
     EXAONE4 = auto()
     GRANITE = auto()
@@ -700,6 +701,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.T5ENCODER: "t5encoder",
     MODEL_ARCH.JAIS: "jais",
     MODEL_ARCH.NEMOTRON: "nemotron",
+    MODEL_ARCH.NEMOTRON_H: "nemotron_h",
     MODEL_ARCH.EXAONE: "exaone",
     MODEL_ARCH.EXAONE4: "exaone4",
     MODEL_ARCH.GRANITE: "granite",
@@ -2297,6 +2299,25 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.NEMOTRON_H: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.EXAONE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,

gguf-py/gguf/tensor_mapping.py

Lines changed: 6 additions & 0 deletions
@@ -191,6 +191,7 @@ class TensorNameMap:
         "model.layers.{bid}.self_attn.q_proj",    # llama4
         "model.transformer.blocks.{bid}.q_proj",  # llada
         "layers.{bid}.self_attn.q_proj",          # qwen3-embedding
+        "backbone.layers.{bid}.mixer.q_proj",     # nemotron-h
     ),

     # Attention key
@@ -209,6 +210,7 @@ class TensorNameMap:
         "model.layers.{bid}.self_attn.k_proj",    # llama4
         "model.transformer.blocks.{bid}.k_proj",  # llada
         "layers.{bid}.self_attn.k_proj",          # qwen3-embedding
+        "backbone.layers.{bid}.mixer.k_proj",     # nemotron-h
     ),

     # Attention value
@@ -226,6 +228,7 @@ class TensorNameMap:
         "model.layers.{bid}.self_attn.v_proj",    # llama4
         "model.transformer.blocks.{bid}.v_proj",  # llada
         "layers.{bid}.self_attn.v_proj",          # qwen3-embedding
+        "backbone.layers.{bid}.mixer.v_proj",     # nemotron-h
     ),

     # Attention output
@@ -260,6 +263,7 @@ class TensorNameMap:
         "transformer_encoder.{bid}.wo",             # neobert
         "model.transformer.blocks.{bid}.attn_out",  # llada
         "layers.{bid}.self_attn.o_proj",            # qwen3-embedding
+        "backbone.layers.{bid}.mixer.o_proj",       # nemotron-h
     ),

     # Attention output norm
@@ -387,6 +391,7 @@ class TensorNameMap:
         "model.layers.{bid}.block_sparse_moe.up",  # smallthinker
         "model.transformer.blocks.{bid}.up_proj",  # llada
         "layers.{bid}.mlp.up_proj",                # qwen3-embedding
+        "backbone.layers.{bid}.mixer.up_proj",     # nemotron-h
     ),

     MODEL_TENSOR.FFN_UP_EXP: (
@@ -480,6 +485,7 @@ class TensorNameMap:
         "model.layers.{bid}.block_sparse_moe.down",  # smallthinker
         "model.transformer.blocks.{bid}.ff_out",     # llada
         "layers.{bid}.mlp.down_proj",                # qwen3-embedding
+        "backbone.layers.{bid}.mixer.down_proj",     # nemotron-h
     ),

     MODEL_TENSOR.FFN_DOWN_EXP: (
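A toy sketch of the {bid}-template lookup these new entries feed into. The TEMPLATES dict and map_name helper below are simplified stand-ins for what TensorNameMap actually does, written only to show how a Nemotron-H checkpoint name lands on a GGUF-style tensor name:

# Hedged sketch: destination names follow the "blk.%d.*" convention used for
# nemotron_h in llama-arch.cpp; the helper itself is not the real API.
TEMPLATES = {
    "backbone.layers.{bid}.mixer.q_proj":    "blk.{bid}.attn_q",
    "backbone.layers.{bid}.mixer.k_proj":    "blk.{bid}.attn_k",
    "backbone.layers.{bid}.mixer.v_proj":    "blk.{bid}.attn_v",
    "backbone.layers.{bid}.mixer.o_proj":    "blk.{bid}.attn_output",
    "backbone.layers.{bid}.mixer.up_proj":   "blk.{bid}.ffn_up",
    "backbone.layers.{bid}.mixer.down_proj": "blk.{bid}.ffn_down",
}

def map_name(hf_name, n_layers=64):
    # Try every template with every block index; return the GGUF-style name.
    for bid in range(n_layers):
        for src, dst in TEMPLATES.items():
            if hf_name == src.format(bid=bid):
                return dst.format(bid=bid)
    return None

print(map_name("backbone.layers.3.mixer.q_proj"))    # blk.3.attn_q
print(map_name("backbone.layers.10.mixer.up_proj"))  # blk.10.ffn_up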

src/llama-arch.cpp

Lines changed: 27 additions & 0 deletions
@@ -69,6 +69,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_T5ENCODER, "t5encoder" },
     { LLM_ARCH_JAIS, "jais" },
     { LLM_ARCH_NEMOTRON, "nemotron" },
+    { LLM_ARCH_NEMOTRON_H, "nemotron_h" },
     { LLM_ARCH_EXAONE, "exaone" },
     { LLM_ARCH_EXAONE4, "exaone4" },
     { LLM_ARCH_RWKV6, "rwkv6" },
@@ -1550,6 +1551,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_NEMOTRON_H,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            // mamba(2) ssm layers
+            { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
+            { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+            // attention layers
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            // dense FFN
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_EXAONE,
         {
@@ -2355,6 +2381,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
         case LLM_ARCH_PLAMO2:
         case LLM_ARCH_GRANITE_HYBRID:
         case LLM_ARCH_LFM2:
+        case LLM_ARCH_NEMOTRON_H:
             return true;
         default:
             return false;
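To tie the hybrid pattern back to the tensor-name table just added, here is a rough sketch of which per-block names one would expect for each layer type. It assumes, based on the table above, that blk.%d.attn_norm serves as the per-block norm for all three layer types, and the three-character pattern is invented for the example:

# Hedged sketch only: expected per-block tensor base names for a hybrid model.
SSM_TENSORS  = ["ssm_in", "ssm_conv1d", "ssm_dt", "ssm_a", "ssm_d", "ssm_norm", "ssm_out"]
ATTN_TENSORS = ["attn_q", "attn_k", "attn_v", "attn_output"]
FFN_TENSORS  = ["ffn_down", "ffn_up"]

def expected_block_tensors(pattern):
    # "M" = Mamba2 block, "*" = attention block, "-" = MLP block.
    per_type = {"M": SSM_TENSORS, "*": ATTN_TENSORS, "-": FFN_TENSORS}
    out = []
    for bid, kind in enumerate(pattern):
        names = [f"blk.{bid}.attn_norm"] + [f"blk.{bid}.{t}" for t in per_type[kind]]
        out.append(names)
    return out

# Toy pattern, not from a real config:
for names in expected_block_tensors("M*-"):
    print(names)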

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ enum llm_arch {
     LLM_ARCH_T5ENCODER,
     LLM_ARCH_JAIS,
     LLM_ARCH_NEMOTRON,
+    LLM_ARCH_NEMOTRON_H,
     LLM_ARCH_EXAONE,
     LLM_ARCH_EXAONE4,
     LLM_ARCH_RWKV6,

src/llama-model-loader.cpp

Lines changed: 1 addition & 0 deletions
@@ -788,6 +788,7 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri
 }

 struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags) {
+    LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str());
     const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));

     if (cur == NULL) {
