
Commit 4d571ed

Let's dump extra tensors
1 parent 2cab86a commit 4d571ed

2 files changed: 25 additions & 11 deletions


examples/model-conversion/scripts/causal/run-org-model-multi-token.py

Lines changed: 22 additions & 11 deletions
@@ -40,9 +40,11 @@
 ### == END ROPE DEBUG ===
 
 token_counter = {}
+layer_counter = {}
+num_model_layers = 0
 
 def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
-    global token, token_counter
+    global num_model_layers, layer_counter, token_counter
     """
     Print a tensor in llama.cpp debug style.
 
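A side note on the widened global statement: Python only requires global when a function rebinds a module-level name, not when it mutates a dict in place or merely reads a value, so the declaration here is for clarity more than necessity. A minimal illustration of the distinction:

counters = {}   # module-level dict
total = 0       # module-level int

def bump(name):
    # in-place mutation works without a global declaration
    counters[name] = counters.get(name, 0) + 1

def bump_total():
    global total   # rebinding the name itself does require it
    total += 1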
@@ -120,15 +122,27 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
         print(" ]")
     print(f" sum = {t.sum().item():.6f}\n")
 
-    pattern = r"model\.layers\.[0-9]+_out"
-    pattern2 = r"recurrent_cache_[0-9]+"
-    if re.fullmatch(pattern, name) or re.fullmatch(pattern2, name):
+    indexed_patterns = [ r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+" ]
+    non_indexed_patterns = [ r"k_pad", r"v_pad", r"q_pad" ]
+
+    if any(re.fullmatch(p, name) for p in indexed_patterns):
         if name not in token_counter:
             token_counter[name] = 1
         else:
             token_counter[name] = token_counter[name] + 1
         save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")
 
+    if any(re.fullmatch(p, name) for p in non_indexed_patterns):
+        if name not in token_counter:
+            token_counter[name] = 1
+        else:
+            token_counter[name] = token_counter[name] + 1
+        if name not in layer_counter or layer_counter[name] == num_model_layers - 1:
+            layer_counter[name] = 0
+        else:
+            layer_counter[name] = layer_counter[name] + 1
+        save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name] - 1}_{token_counter[name]}.bin")
+
 from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
 orig_rope = apply_rotary_pos_emb
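To make the new non-indexed branch concrete, here is a minimal standalone simulation of the bookkeeping above (hypothetical setup: 2 model layers and 2 generated tokens, with save_tensor replaced by a print):

num_model_layers = 2
token_counter, layer_counter = {}, {}

def record(name):
    # same logic as the non-indexed branch in summarize()
    token_counter[name] = token_counter.get(name, 0) + 1
    if name not in layer_counter or layer_counter[name] == num_model_layers - 1:
        layer_counter[name] = 0
    else:
        layer_counter[name] = layer_counter[name] + 1
    print(f"{name}_{layer_counter[name] - 1}_{token_counter[name]}.bin")

for _token in range(2):
    for _layer in range(num_model_layers):
        record("k_pad")
# prints: k_pad_-1_1.bin, k_pad_0_2.bin, k_pad_-1_3.bin, k_pad_0_4.bin

As written, the layer index embedded in the filename runs from -1 to num_model_layers - 2 (the name is formatted with layer_counter[name] - 1 after the counter has already been advanced), and the trailing number counts occurrences across all layers and tokens rather than per token.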
@@ -223,10 +237,8 @@ def patched_torch_chunk_gated_delta_rule(
     chunk_size=64,
     initial_state=None,
     output_final_state=False,
-    use_qk_l2norm_in_kernel=False,
-    long=False
+    use_qk_l2norm_in_kernel=False
 ):
-    torch.set_printoptions(threshold=10_000_000, sci_mode=False, precision=10, linewidth=200)
     initial_dtype = query.dtype
     [ summarize(x, y) for (x, y) in ((query, "q_prenorm"), (key, "k_prenorm")) ]
     if use_qk_l2norm_in_kernel:
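For orientation, patched_torch_chunk_gated_delta_rule follows the same monkey-patching pattern the script applies to torch_causal_conv1d_update and apply_rotary_pos_emb: keep a reference to the original, wrap it, and rebind the module attribute. A generic sketch of that pattern (not the script's exact wiring, and assuming the modeling module attribute can be rebound this way):

import transformers.models.qwen3_next.modeling_qwen3_next as m

orig_fn = m.torch_causal_conv1d_update   # keep the original

def patched_fn(*args, **kwargs):
    out = orig_fn(*args, **kwargs)       # delegate to the real implementation
    # summarize()/save_tensor() calls would go here
    return out

m.torch_causal_conv1d_update = patched_fn  # rebind so callers hit the wrapper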
@@ -359,13 +371,10 @@ def patched_torch_chunk_gated_delta_rule(
     core_attn_out = core_attn_out[:, :, :num_heads]
     core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
     summarize(core_attn_out, "attn_out")
-    if not long:
-        print(f"attn_out:\n{core_attn_out}\n\n")
 
     if isinstance(last_recurrent_state, torch.Tensor):
         summarize(last_recurrent_state, "state_out")
-        if not long:
-            print(f"state_out:\n{last_recurrent_state}\n\n")
+
     return core_attn_out, last_recurrent_state
 
 
@@ -667,6 +676,8 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
 print("BOS token id: ", config.bos_token_id)
 print("EOS token id: ", config.eos_token_id)
 
+num_model_layers = config.num_hidden_layers
+
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 config = AutoConfig.from_pretrained(model_path)
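The num_model_layers global that drives the layer cycling is filled in once from the model config, which the script already loads via transformers. In isolation the lookup is just (model path is a placeholder):

from transformers import AutoConfig

model_path = "path/to/model"  # placeholder
config = AutoConfig.from_pretrained(model_path)
num_model_layers = config.num_hidden_layers  # e.g. 48 for a 48-layer model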

tools/main/main.cpp

Lines changed: 3 additions & 0 deletions
@@ -246,6 +246,9 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
     uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
     std::string tensor_name(t->name);
     if (std::string(tensor_name).substr(0, std::string("post_moe-").size()) == "post_moe-" ||
+        std::string(tensor_name).substr(0, std::string("k_pad-").size()) == "k_pad-" ||
+        std::string(tensor_name).substr(0, std::string("q_pad-").size()) == "q_pad-" ||
+        std::string(tensor_name).substr(0, std::string("v_pad-").size()) == "v_pad-" ||
         std::string(tensor_name).substr(0, std::string("state_1d-").size()) == "state_1d-") {
 
         if (cb_data->tensors.count(tensor_name) == 0) {
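The condition is the same prefix test repeated once per name; expressed compactly, the equivalent check is a startswith over the prefix set (a Python sketch for reference, not a proposed change to the C++):

DUMP_PREFIXES = ("post_moe-", "k_pad-", "q_pad-", "v_pad-", "state_1d-")

def should_dump(tensor_name: str) -> bool:
    # str.startswith accepts a tuple of prefixes
    return tensor_name.startswith(DUMP_PREFIXES)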
