
Commit ca4c978

resolving tensor dimensions
1 parent 154459a commit ca4c978

2 files changed: +39 additions, -32 deletions

convert_hf_to_gguf.py

Lines changed: 20 additions & 19 deletions
@@ -7979,28 +7979,29 @@ def modify_tensors(self, data_torch, name, bid):
         elif any(x in layer_component for x in ["A_log", "D", "conv1d", "dt_bias", "in_proj", "mixer.norm", "out_proj"]):
             # Mamba layer tensors (note: mixer.norm, not just norm.weight)
             new_name = self._map_mamba_tensor(layer_component, bid)
-            # Special handling for conv1d: reshape from 3D to 2D
-            if "conv1d.weight" in layer_component and len(data_torch.shape) == 3:
-                data_torch = data_torch.squeeze(1) # Remove middle dimension: {4,1,12288} -> {4,12288}
-            # A_log -> A = -exp(A_log) and ensure [1,128] shape for llama.cpp
+            # NVIDIA GROUND TRUTH TENSOR TRANSFORMATIONS
+
+            # Conv1d: NVIDIA [12288, 1, 4] -> llama.cpp [4, 12288]
+            if "conv1d.weight" in layer_component:
+                if len(data_torch.shape) == 3: # [12288, 1, 4]
+                    data_torch = data_torch.squeeze(1).t().contiguous() # [12288, 4] -> [4, 12288]
+
+            # A_log: NVIDIA [128] -> llama.cpp [1, 128] with -exp transform
             if layer_component.endswith("A_log"):
-                data_torch = -torch.exp(data_torch)
-                # Ensure 2D shape [1, d_state] for llama.cpp compatibility
-                if len(data_torch.shape) == 1:
-                    data_torch = data_torch.unsqueeze(-1) # [128] -> [128,1] -> store as [1,128] in GGUF
-                elif len(data_torch.shape) == 4 and data_torch.shape[1:] == (1, 1, 1):
-                    data_torch = data_torch.reshape(data_torch.shape[0], 1) # [128,1,1,1] -> [128,1]
-            # D tensor also needs reshaping to [1,128] for llama.cpp
+                data_torch = -torch.exp(data_torch) # Apply -exp transformation
+                if len(data_torch.shape) == 1: # [128]
+                    data_torch = data_torch.unsqueeze(0) # -> [1, 128]
+
+            # D: NVIDIA [128] -> llama.cpp [1, 128]
             if layer_component.endswith("D"):
-                # Ensure 2D shape [1, d_state] for llama.cpp compatibility
-                if len(data_torch.shape) == 1:
-                    data_torch = data_torch.unsqueeze(-1) # [128] -> [128,1] -> store as [1,128] in GGUF
-                elif len(data_torch.shape) == 4 and data_torch.shape[1:] == (1, 1, 1):
-                    data_torch = data_torch.reshape(data_torch.shape[0], 1) # [128,1,1,1] -> [128,1]
-            # Grouped RMSNorm reshape to [actual_size/n_group, n_group]
+                if len(data_torch.shape) == 1: # [128]
+                    data_torch = data_torch.unsqueeze(0) # -> [1, 128]
+
+            # Grouped RMSNorm: NVIDIA [10240] -> llama.cpp [1280, 8]
             if layer_component == "mixer.norm.weight":
-                actual_size = data_torch.numel()
-                data_torch = data_torch.reshape(actual_size // self.n_group, self.n_group)
+                if len(data_torch.shape) == 1: # [10240]
+                    # 10240 elements = 1280 * 8 groups
+                    data_torch = data_torch.reshape(1280, 8)
             # in_proj needs split order expected by llama.cpp mamba2 builder: [z, xBC, dt]
             if layer_component == "mixer.in_proj.weight":
                 W = data_torch
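As a sanity check on the new "+" branch above, the shape transformations can be reproduced in isolation. This is a minimal sketch using hypothetical stand-in tensors (random values in the shapes named by the diff comments), not a call into the converter itself:

import torch

# Stand-in tensors with the NVIDIA-side shapes quoted in the comments above.
conv1d_w = torch.randn(12288, 1, 4)   # conv1d.weight      [12288, 1, 4]
a_log    = torch.randn(128)           # A_log              [128]
d_skip   = torch.randn(128)           # D                  [128]
norm_w   = torch.randn(10240)         # mixer.norm.weight  [10240]

# Conv1d: drop the singleton middle dim, then transpose -> [4, 12288]
conv1d_w = conv1d_w.squeeze(1).t().contiguous()
assert conv1d_w.shape == (4, 12288)

# A_log -> A = -exp(A_log), promoted to 2-D [1, 128]
a = (-torch.exp(a_log)).unsqueeze(0)
assert a.shape == (1, 128)

# D: same 2-D promotion, [128] -> [1, 128]
d_skip = d_skip.unsqueeze(0)
assert d_skip.shape == (1, 128)

# Grouped RMSNorm: 10240 elements = 1280 * 8 groups -> [1280, 8]
norm_w = norm_w.reshape(1280, 8)
assert norm_w.shape == (1280, 8)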

src/llama-model.cpp

Lines changed: 19 additions & 13 deletions
@@ -3773,8 +3773,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
     for (int i = 0; i < n_layer; ++i) {
         auto & layer = layers[i];
-        bool is_mamba_layer = hparams.is_recurrent(i);
-        bool is_attention_layer = (i == 14 || i == 21 || i == 30 || i == 39); // Known attention layers for Nemotron-H 9B
+        // Nemotron-H 9B ground truth layer structure (56 total layers):
+        // 27 SSM layers: [0,2,4,6,7,9,11,13,16,18,20,23,25,27,29,32,34,36,38,41,43,44,46,48,50,52,54]
+        // 25 MLP layers: [1,3,5,8,10,12,15,17,19,22,24,26,28,31,33,35,37,40,42,45,47,49,51,53,55]
+        // 4 Attention layers: [14,21,30,39]
+        std::vector<int> ssm_layers = {0,2,4,6,7,9,11,13,16,18,20,23,25,27,29,32,34,36,38,41,43,44,46,48,50,52,54};
+        std::vector<int> attention_layers = {14,21,30,39};
+
+        bool is_mamba_layer = std::find(ssm_layers.begin(), ssm_layers.end(), i) != ssm_layers.end();
+        bool is_attention_layer = std::find(attention_layers.begin(), attention_layers.end(), i) != attention_layers.end();
 
         // norm (all layers have this)
         layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
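The hard-coded layer lists above can be cross-checked against the stated totals. A small sketch (the index lists are copied from the comments in the hunk; the check itself is ours) confirming that the three categories partition all 56 layers:

ssm_layers = {0,2,4,6,7,9,11,13,16,18,20,23,25,27,29,32,34,36,38,41,43,44,46,48,50,52,54}
mlp_layers = {1,3,5,8,10,12,15,17,19,22,24,26,28,31,33,35,37,40,42,45,47,49,51,53,55}
attention_layers = {14,21,30,39}

assert len(ssm_layers) == 27 and len(mlp_layers) == 25 and len(attention_layers) == 4
# Disjoint and jointly covering layers 0..55 (27 + 25 + 4 = 56).
assert ssm_layers | mlp_layers | attention_layers == set(range(56))
assert ssm_layers.isdisjoint(mlp_layers) and ssm_layers.isdisjoint(attention_layers) and mlp_layers.isdisjoint(attention_layers)

The second hunk of this file follows.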
@@ -3784,24 +3791,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             // in_proj packs [x1, B, C, x2, dt_hat] in this kernel order
             layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
 
-            // depthwise conv over the first partition (x1 only, not full x1+B+C)
-            // Nemotron-H conv1d dims: 12288 (not the full d_x_part = 17728)
+            // depthwise conv: GGUF has {12288, 4} due to conversion - adapt to ground truth
+            // NVIDIA ground truth: [12288, 1, 4] -> GGUF: {12288, 4}
             const int64_t nemotron_conv_dim = 12288;
-            layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, nemotron_conv_dim}, 0);
+            layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {nemotron_conv_dim, d_conv}, 0);
             layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {nemotron_conv_dim}, 0);
 
             // time step bias for low-rank delta
             layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_state}, 0); // Use d_state (128) not n_head (80)
 
             // SSM decay and skip parameters per SSM state dimension
-            // Nemotron-H uses d_state (128) not dt_rank (122) for A and D tensors
-            layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, d_state}, 0);
-            layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, d_state}, 0);
-
-            // grouped RMSNorm for the SSM inner stream (actual tensor size is 10240 not d_inner)
-            // Nemotron-H norm tensor: 10240 elements reshaped to [1280, 8]
-            const int64_t norm_elements_per_group = 1280; // 10240 / 8
-            layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {norm_elements_per_group, n_group}, 0);
+            // Nemotron-H: GGUF has A,D as {128, 1} due to conversion - adapt to ground truth
+            layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, 1}, 0);
+            layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_state, 1}, 0);
+
+            // grouped RMSNorm: GGUF has {8, 1280} due to conversion - adapt to ground truth
+            // 10240 total elements grouped as 8 groups of 1280 elements each
+            layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_group, 1280}, 0);
             // out_proj back to model dim (actual tensor is [4480, 10240] not [15680, 4480])
             // Nemotron-H out_proj: 10240 -> 4480 (not d_inner -> n_embd)
             const int64_t out_proj_input_dim = 10240; // Actual SSM output dim
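The shapes in this hunk look transposed relative to what convert_hf_to_gguf.py writes, but that is just ggml's dimension order: ne[] lists the fastest-varying dimension first, so a row-major torch tensor of shape [r, c] shows up in llama.cpp as {c, r}. A tiny illustrative sketch (the helper name is ours, not ggml API):

def ggml_ne(torch_shape):
    # ggml's ne[] is the reverse of a row-major torch/numpy shape.
    return tuple(reversed(torch_shape))

# Converter output shape -> expected create_tensor() dims
assert ggml_ne((4, 12288)) == (12288, 4)   # conv1d: {nemotron_conv_dim, d_conv}
assert ggml_ne((1, 128))   == (128, 1)     # A, D:   {d_state, 1}
assert ggml_ne((1280, 8))  == (8, 1280)    # norm:   {n_group, 1280}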
