Commit 5417f32

Wrong dimension order
1 parent e5ffc91 commit 5417f32

File tree: 6 files changed, +10484 -80 lines

comp.sh

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+#!/bin/bash
+echo "Running converted model."
+llama-cli -no-cnv -m reference/qwen3_ntl/qwen3_ntl.gguf -p "Once upon a time" -n 30 --temp 0 &> data/tinylong-30-tok.txt
+echo "Running original model."
+python examples/model-conversion/scripts/causal/run-org-model-multi-token.py --model-path reference/qwen3_ntl --num-tokens 30 --prompt "Once upon a time" &> data/tinylong-30-tok-org.txt
+echo "Running tensor comparison."
+python reference/compare_tensors.py 30 16 &> data/tinylong-30-compare.txt
+echo "Done."

examples/model-conversion/scripts/causal/run-org-model-multi-token.py

Lines changed: 66 additions & 77 deletions
@@ -43,84 +43,19 @@
 layer_counter = {}
 num_model_layers = 0

-def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
+
+def summarize(tensor: torch.Tensor, name: str, max_seq: int = 4):
+    torch.set_printoptions(precision = 6, edgeitems = max_seq, linewidth = 160, sci_mode = False, threshold = 50)
     global num_model_layers, layer_counter, token_counter
     """
     Print a tensor in llama.cpp debug style.
-
-    Supports:
-    - 2D tensors (seq, hidden)
-    - 3D tensors (batch, seq, hidden)
-    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
-    - 5D tensors
-
     Shows first and last max_vals of each vector per sequence position.
     """
     t = tensor.detach().to(torch.float32).cpu()
-    ten_shape = t.shape
-    while t.ndim > 4:
-        t = t.squeeze(0)
-
-    # Determine dimensions
-    if t.ndim == 3:
-        _, s, _ = t.shape
-    elif t.ndim == 2:
-        _, s = 1, t.shape[0]
-        t = t.unsqueeze(0)
-    elif t.ndim == 4:
-        _, s, _, _ = t.shape
-
-    else:
-        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
-        return
-
-    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
-    print(" [")
-    print(" [")
-
-    # Determine indices for first and last sequences
-    first_indices = list(range(min(s, max_seq)))
-    last_indices = list(range(max(0, s - max_seq), s))

-    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
-    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
-
-    # Combine indices
-    if has_overlap:
-        # If there's overlap, just use the combined unique indices
-        indices = sorted(list(set(first_indices + last_indices)))
-        separator_index = None
-    else:
-        # If no overlap, we'll add a separator between first and last sequences
-        indices = first_indices + last_indices
-        separator_index = len(first_indices)
-
-    for i, si in enumerate(indices):
-        # Add separator if needed
-        if separator_index is not None and i == separator_index:
-            print(" ...")
-
-        # Extract appropriate slice
-        vec = t[0, si]
-        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
-            flat = vec.flatten().tolist()
-        else:  # 2D or 3D case
-            flat = vec.tolist()
-
-        # First and last slices
-        first = flat[:max_vals]
-        last = flat[-max_vals:] if len(flat) >= 2 * max_vals else flat
-        first_str = ", ".join(f"{v:12.4f}" for v in first)
-        last_str = ", ".join(f"{v:12.4f}" for v in last)
-
-        if len(flat) >= 2 * max_vals:
-            print(f" [{first_str}, ..., {last_str}]")
-        else:
-            print(f" [{last_str}]")
-
-    print(" ],")
-    print(" ]")
-    print(f" sum = {t.sum().item():.6f}\n")
+    print(f"ggml_debug: {name} = (f32) ... = {{{t.shape}}}\n")
+    print(t)
+    print(f"\n sum = {t.sum().item():.6f}\n")

 indexed_patterns = [ r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+" ]
 non_indexed_patterns = [ r"k_pad", r"v_pad", r"q_scaled" ]
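
Note: the rewritten summarize() drops the hand-rolled first/last slicing and leans on torch's built-in print elision instead. A standalone sketch (not part of the commit) of the behavior it relies on:

    import torch

    # edgeitems controls how many leading/trailing entries torch prints per
    # dimension; threshold is the element count above which output is elided
    # with "..." -- the same options the new summarize() sets.
    torch.set_printoptions(precision=6, edgeitems=4, linewidth=160, sci_mode=False, threshold=50)

    t = torch.arange(400, dtype=torch.float32).reshape(20, 20)
    print(t)                              # 4 rows/cols at each edge, "..." in between
    print(f"sum = {t.sum().item():.6f}")  # sum = 79800.000000
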
@@ -146,11 +81,41 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
         layer_counter[name] = layer_counter[name] + 1  # skip attention layers
         save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name]}_{token_counter[name]}.bin")

-from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402
+from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm, repeat_kv  # noqa: E402
+from transformers.processing_utils import Unpack  # noqa: E402
+from transformers.utils.generic import TransformersKwargs  # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
 orig_rope = apply_rotary_pos_emb
 import torch.nn.functional as F  # noqa: E402
 import typing  # noqa: E402
+from torch import nn  # noqa: E402
+
+def patched_eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: typing.Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    print(f"\nAttention scaling: {scaling}\n")
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    summarize(attn_output, "attn_output")
+
+    return attn_output, attn_weights

 def patched_torch_causal_conv1d_update(
     hidden_states,
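
Note: repeat_kv (newly imported above) tiles each key/value head across its query-head group so grouped-query attention can matmul against all query heads; the scaling argument is typically 1 / sqrt(head_dim). The snippet below is an illustrative re-implementation with assumed shapes, not the transformers source:

    import torch

    def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
        batch, num_kv_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
        return expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

    k = torch.randn(1, 2, 5, 8)          # 2 KV heads, 5 positions, head_dim 8
    print(repeat_kv_sketch(k, 4).shape)  # torch.Size([1, 8, 5, 8]): one slot per query head
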
@@ -343,7 +308,13 @@ def patched_torch_chunk_gated_delta_rule(
         summarize(k_i, f"k_i_chunk_{i}")
         summarize(v_i, f"v_i_chunk_{i}")

-        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+        q_k_trans = q_i @ k_i.transpose(-1, -2)
+        summarize(q_k_trans, f"q_k_trans_{i}")
+
+        q_k_trans_decay = q_k_trans * decay_mask[:, :, i]
+        summarize(q_k_trans_decay, f"q_k_trans_decay_{i}")
+
+        attn = q_k_trans_decay.masked_fill_(mask, 0)
         summarize(attn, f"attn_chunk_{i}")

         v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
@@ -362,16 +333,31 @@ def patched_torch_chunk_gated_delta_rule(
         summarize(g_last, f"g_last_chunk_{i}")

         g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
-        last_recurrent_state = (
-            last_recurrent_state * g_last
-            + (k_i * g_diff_exp[..., None]).transpose(-1, -2) @ v_new
-        )
+        summarize(g_diff_exp, f"g_diff_exp_chunk_{i}")
+
+        state_g_last = last_recurrent_state * g_last
+        summarize(state_g_last, f"state_g_last_{i}")
+
+        k_g_diffexp = (k_i * g_diff_exp[..., None])
+        summarize(k_g_diffexp, f"k_g_diffexp_{i}")
+
+        k_g_diffexp_T = k_g_diffexp.transpose(-1, -2)
+        summarize(k_g_diffexp_T, f"k_g_diffexp_T_{i}")
+
+        kgd_mul_vnew = k_g_diffexp_T @ v_new
+        summarize(kgd_mul_vnew, f"kgd_mul_vnew_{i}")
+
+        last_recurrent_state = state_g_last + kgd_mul_vnew
         summarize(last_recurrent_state, f"updated_state_chunk_{i}")

     if not output_final_state:
         last_recurrent_state = None
     core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
+    summarize(core_attn_out, "attn_out_reshaped")
+
     core_attn_out = core_attn_out[:, :, :num_heads]
+    summarize(core_attn_out, "attn_out_truncated")
+
     core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
     summarize(core_attn_out, "attn_out")
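
Note: the state update (like the attn computation in the previous hunk) is decomposed purely so each intermediate can be logged via summarize(). A quick standalone check, with assumed toy shapes, that the split form matches the fused expression it replaces:

    import torch

    b, h, c, dk, dv = 1, 2, 4, 8, 8  # assumed toy dimensions
    state = torch.randn(b, h, dk, dv)
    g_last = torch.rand(b, h, 1, 1)
    k_i = torch.randn(b, h, c, dk)
    g_diff_exp = torch.rand(b, h, c)
    v_new = torch.randn(b, h, c, dv)

    # decomposed, as in the patched code above
    state_g_last = state * g_last
    k_g_diffexp_T = (k_i * g_diff_exp[..., None]).transpose(-1, -2)
    decomposed = state_g_last + k_g_diffexp_T @ v_new

    # fused, as it was before this commit
    fused = state * g_last + (k_i * g_diff_exp[..., None]).transpose(-1, -2) @ v_new
    assert torch.allclose(decomposed, fused)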

@@ -451,6 +437,7 @@ def patched_torch_recurrent_gated_delta_rule(
 qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
 qwen_mod.apply_rotary_pos_emb = patched_apply_rope
 qwen_mod.torch_recurrent_gated_delta_rule = patched_torch_recurrent_gated_delta_rule
+qwen_mod.eager_attention_forward = patched_eager_attention_forward

 # Store original functions for patching
 original_functions = {}
@@ -736,6 +723,8 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
 all_generated_tokens = []
 all_logits = []

+model.config._attn_implementation = "eager"
+
 with torch.no_grad():
     # Initial forward pass
     print(f"\n=== Initial Forward Pass ===")
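
Note: forcing _attn_implementation to "eager" matters because transformers selects the attention kernel from this config string; on a fused path such as "sdpa" the module-level eager_attention_forward, and hence the monkey-patch above, would never run. A toy sketch of that dispatch pattern (illustrative, not the transformers internals):

    def patched_eager():
        return "eager path: summarize() hooks fire"

    def fused_sdpa():
        return "fused kernel: monkey-patches are bypassed"

    # the config string picks the callable, so only "eager" reaches the patch
    ATTN_IMPLS = {"eager": patched_eager, "sdpa": fused_sdpa}
    print(ATTN_IMPLS["eager"]())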

ggml/src/ggml-cpu/ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -11406,7 +11406,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                     *(state_ptr(seq, head, i, j)) = temp_state[state_idx];

                     // Store the final state for this head and sequence (for output)
-                    int64_t final_state_idx = i + j * S_v + head * (S_v * S_v) + seq * (S_v * S_v * H_v);
+                    int64_t final_state_idx = j + i * S_v + head * (S_v * S_v) + seq * (S_v * S_v * H_v);
                     final_state[final_state_idx] = temp_state[state_idx];
                 }
             }
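
Note: this one-line swap is the "wrong dimension order" of the commit title. With j as the fastest-varying index of the flattened per-head S_v x S_v state, the old expression i + j * S_v addressed the transposed element (j, i). A toy Python check of the arithmetic (illustrative only, assumed toy size):

    S_v = 3
    flat = list(range(S_v * S_v))  # stored with j fastest: flat[j + i*S_v] is row i, col j
    old = [[flat[i + j * S_v] for j in range(S_v)] for i in range(S_v)]
    new = [[flat[j + i * S_v] for j in range(S_v)] for i in range(S_v)]
    # the old indexing produced the transpose of the intended matrix
    assert old == [[new[j][i] for j in range(S_v)] for i in range(S_v)]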

src/llama-graph.cpp

Lines changed: 1 addition & 1 deletion
@@ -1532,7 +1532,7 @@ ggml_tensor * llm_graph_context::build_attn(

     if (wo) {
         cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_QWEN3NEXT) {
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
             // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }

0 commit comments