Commit cc9b929

committed: still isn't working, though progress is being made
1 parent 62accf9 commit cc9b929

File tree: 2 files changed (+20, -10 lines)


convert_hf_to_gguf.py

Lines changed: 12 additions & 6 deletions
@@ -7982,15 +7982,21 @@ def modify_tensors(self, data_torch, name, bid):
         # Special handling for conv1d: reshape from 3D to 2D
         if "conv1d.weight" in layer_component and len(data_torch.shape) == 3:
             data_torch = data_torch.squeeze(1)  # Remove middle dimension: {4,1,12288} -> {4,12288}
-        # A_log -> A = -exp(A_log) and reshape from [128,1,1,1] to [1,128]
+        # A_log -> A = -exp(A_log) and ensure [1,128] shape for llama.cpp
         if layer_component.endswith("A_log"):
             data_torch = -torch.exp(data_torch)
-            if len(data_torch.shape) == 4 and data_torch.shape[1:] == (1, 1, 1):
-                data_torch = data_torch.reshape(1, data_torch.shape[0])  # [128,1,1,1] -> [1,128]
-        # D tensor also needs reshaping from [128,1,1,1] to [1,128]
+            # Ensure 2D shape [1, d_state] for llama.cpp compatibility
+            if len(data_torch.shape) == 1:
+                data_torch = data_torch.unsqueeze(-1)  # [128] -> [128,1] -> store as [1,128] in GGUF
+            elif len(data_torch.shape) == 4 and data_torch.shape[1:] == (1, 1, 1):
+                data_torch = data_torch.reshape(data_torch.shape[0], 1)  # [128,1,1,1] -> [128,1]
+        # D tensor also needs reshaping to [1,128] for llama.cpp
         if layer_component.endswith("D"):
-            if len(data_torch.shape) == 4 and data_torch.shape[1:] == (1, 1, 1):
-                data_torch = data_torch.reshape(1, data_torch.shape[0])  # [128,1,1,1] -> [1,128]
+            # Ensure 2D shape [1, d_state] for llama.cpp compatibility
+            if len(data_torch.shape) == 1:
+                data_torch = data_torch.unsqueeze(-1)  # [128] -> [128,1] -> store as [1,128] in GGUF
+            elif len(data_torch.shape) == 4 and data_torch.shape[1:] == (1, 1, 1):
+                data_torch = data_torch.reshape(data_torch.shape[0], 1)  # [128,1,1,1] -> [128,1]
         # Grouped RMSNorm reshape to [actual_size/n_group, n_group]
         if layer_component == "mixer.norm.weight":
             actual_size = data_torch.numel()
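
For reference, a minimal standalone sketch of the shape normalization applied to A_log and D above, using dummy PyTorch tensors; it is not the converter itself. The [128] and [128,1,1,1] input shapes and d_state = 128 are assumptions taken from the diff comments, and the note that a [128,1] torch tensor lands as [1,128] in the GGUF file relies on GGUF's reversed dimension order.

import torch

def normalize_ssm_scalar_tensor(t: torch.Tensor) -> torch.Tensor:
    # Coerce A/D tensors to 2D [d_state, 1] so they are stored as [1, d_state] in GGUF
    if len(t.shape) == 1:
        t = t.unsqueeze(-1)                          # [128] -> [128, 1]
    elif len(t.shape) == 4 and t.shape[1:] == (1, 1, 1):
        t = t.reshape(t.shape[0], 1)                 # [128, 1, 1, 1] -> [128, 1]
    return t

a_log_flat = torch.randn(128)                        # hypothetical checkpoint variant: flat [128]
a_log_4d   = torch.randn(128, 1, 1, 1)               # hypothetical checkpoint variant: [128, 1, 1, 1]

assert normalize_ssm_scalar_tensor(-torch.exp(a_log_flat)).shape == (128, 1)
assert normalize_ssm_scalar_tensor(-torch.exp(a_log_4d)).shape == (128, 1)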

src/llama-model.cpp

Lines changed: 8 additions & 4 deletions
@@ -3798,10 +3798,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, d_state}, 0);
                     layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, d_state}, 0);

-                    // grouped RMSNorm for the SSM inner stream
-                    layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
-                    // out_proj back to model dim
-                    layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
+                    // grouped RMSNorm for the SSM inner stream (actual tensor size is 10240 not d_inner)
+                    // Nemotron-H norm tensor: 10240 elements reshaped to [1280, 8]
+                    const int64_t norm_elements_per_group = 1280; // 10240 / 8
+                    layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {norm_elements_per_group, n_group}, 0);
+                    // out_proj back to model dim (actual tensor is [4480, 10240] not [15680, 4480])
+                    // Nemotron-H out_proj: 10240 -> 4480 (not d_inner -> n_embd)
+                    const int64_t out_proj_input_dim = 10240; // Actual SSM output dim
+                    layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {out_proj_input_dim, n_embd}, 0);
                 } else if (is_attention_layer) {
                     // Attention layer tensors - compute from heads and head dim
                     const int64_t n_head_i = 40; // q heads
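
As a sanity check on the hard-coded constants above, here is a small arithmetic sketch of the Nemotron-H SSM dimensions. The concrete numbers (n_embd = 4480, n_group = 8, d_inner = 15680, SSM output width 10240) are read off the comments in this commit and should be treated as the author's working assumptions rather than verified hyperparameters.

# Dimension bookkeeping behind the hard-coded constants (values assumed from the diff comments)
n_embd      = 4480     # model dim, per the out_proj comment
n_group     = 8        # number of SSM groups
d_inner     = 15680    # dimension the old code used for ssm_norm / ssm_out
ssm_out_dim = 10240    # SSM output width actually observed in the checkpoint

old_norm_shape = (d_inner // n_group, n_group)       # (1960, 8) -> 15680 elements, too large
new_norm_shape = (ssm_out_dim // n_group, n_group)   # (1280, 8) -> 10240 elements, matches
assert new_norm_shape == (1280, 8)
assert new_norm_shape[0] * new_norm_shape[1] == ssm_out_dim

old_out_shape = (d_inner, n_embd)                    # {15680, 4480} as previously declared
new_out_shape = (ssm_out_dim, n_embd)                # {10240, 4480} matching the checkpoint tensor
assert new_out_shape == (10240, 4480)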
