 ### == END ROPE DEBUG ===

 token_counter = {}
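+# Per-name counters: layer_counter derives a layer index for debug tensors
+# whose names do not embed one; num_model_layers is filled in from the model
+# config once it is loaded further down.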
+layer_counter = {}
+num_model_layers = 0

 def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
-    global token, token_counter
+    global num_model_layers, layer_counter, token_counter
     """
     Print a tensor in llama.cpp debug style.

@@ -120,15 +122,27 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = |
     print(" ]")
     print(f" sum = {t.sum().item():.6f}\n")

-    pattern = r"model\.layers\.[0-9]+_out"
-    pattern2 = r"recurrent_cache_[0-9]+"
-    if re.fullmatch(pattern, name) or re.fullmatch(pattern2, name):
+    indexed_patterns = [r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+"]
+    non_indexed_patterns = [r"k_pad", r"v_pad", r"q_pad"]
+
+    if any(re.fullmatch(p, name) for p in indexed_patterns):
         if name not in token_counter:
             token_counter[name] = 1
         else:
             token_counter[name] = token_counter[name] + 1
         save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")

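+    # Resulting layout for, e.g., a 2-layer model after two tokens:
+    #   k_pad_0_1.bin, k_pad_1_1.bin, k_pad_0_2.bin, k_pad_1_2.bin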
+    if any(re.fullmatch(p, name) for p in non_indexed_patterns):
+        # Cycle a per-name layer index and advance the token index each
+        # time the cycle wraps back to the first layer.
+        if name not in layer_counter or layer_counter[name] == num_model_layers - 1:
+            layer_counter[name] = 0
+        else:
+            layer_counter[name] = layer_counter[name] + 1
+        if layer_counter[name] == 0:
+            token_counter[name] = token_counter.get(name, 0) + 1
+        save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name]}_{token_counter[name]}.bin")
+
 from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
 orig_rope = apply_rotary_pos_emb
@@ -223,10 +237,8 @@ def patched_torch_chunk_gated_delta_rule( |
     chunk_size=64,
     initial_state=None,
     output_final_state=False,
-    use_qk_l2norm_in_kernel=False,
-    long=False
+    use_qk_l2norm_in_kernel=False
 ):
-    torch.set_printoptions(threshold=10_000_000, sci_mode=False, precision=10, linewidth=200)
     initial_dtype = query.dtype
     [ summarize(x, y) for (x, y) in ((query, "q_prenorm"), (key, "k_prenorm")) ]
     if use_qk_l2norm_in_kernel:
@@ -359,13 +371,10 @@ def patched_torch_chunk_gated_delta_rule( |
     core_attn_out = core_attn_out[:, :, :num_heads]
     core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
     summarize(core_attn_out, "attn_out")
-    if not long:
-        print(f"attn_out:\n{core_attn_out}\n\n")

     if isinstance(last_recurrent_state, torch.Tensor):
         summarize(last_recurrent_state, "state_out")
-        if not long:
-            print(f"state_out:\n{last_recurrent_state}\n\n")
+
     return core_attn_out, last_recurrent_state

@@ -667,6 +676,8 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name): |
 print("BOS token id: ", config.bos_token_id)
 print("EOS token id: ", config.eos_token_id)

+num_model_layers = config.num_hidden_layers
+
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 config = AutoConfig.from_pretrained(model_path)
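For later comparison, a minimal sketch of how the dumped files could be read back, assuming `save_tensor` (defined elsewhere in this script, outside this hunk) writes the flattened tensor as a raw little-endian float32 buffer; the `load_tensor` helper and the `llamacpp` dump directory are hypothetical:

```python
import numpy as np
import torch

def load_tensor(path: str, shape=None) -> torch.Tensor:
    # Assumes save_tensor wrote a raw little-endian float32 buffer.
    data = np.fromfile(path, dtype="<f4")
    t = torch.from_numpy(data)
    return t.reshape(shape) if shape is not None else t

# Hypothetical usage: compare the HF reference dump against a llama.cpp dump.
ref = load_tensor("reference/tensors/org/k_pad_0_1.bin")
got = load_tensor("reference/tensors/llamacpp/k_pad_0_1.bin")
print("max abs diff:", (ref - got).abs().max().item())
```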