 layer_counter = {}
 num_model_layers = 0
 
-def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
+
+def summarize(tensor: torch.Tensor, name: str, max_seq: int = 4):
+    torch.set_printoptions(precision=6, edgeitems=max_seq, linewidth=160, sci_mode=False, threshold=50)
     global num_model_layers, layer_counter, token_counter
     """
     Print a tensor in llama.cpp debug style.
-
-    Supports:
-    - 2D tensors (seq, hidden)
-    - 3D tensors (batch, seq, hidden)
-    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
-    - 5D tensors
-
     Shows first and last max_vals of each vector per sequence position.
     """
     t = tensor.detach().to(torch.float32).cpu()
-    ten_shape = t.shape
-    while t.ndim > 4:
-        t = t.squeeze(0)
-
-    # Determine dimensions
-    if t.ndim == 3:
-        _, s, _ = t.shape
-    elif t.ndim == 2:
-        _, s = 1, t.shape[0]
-        t = t.unsqueeze(0)
-    elif t.ndim == 4:
-        _, s, _, _ = t.shape
-
-    else:
-        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
-        return
-
-    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
-    print(" [")
-    print(" [")
-
-    # Determine indices for first and last sequences
-    first_indices = list(range(min(s, max_seq)))
-    last_indices = list(range(max(0, s - max_seq), s))
 
-    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
-    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
-
-    # Combine indices
-    if has_overlap:
-        # If there's overlap, just use the combined unique indices
-        indices = sorted(list(set(first_indices + last_indices)))
-        separator_index = None
-    else:
-        # If no overlap, we'll add a separator between first and last sequences
-        indices = first_indices + last_indices
-        separator_index = len(first_indices)
-
-    for i, si in enumerate(indices):
-        # Add separator if needed
-        if separator_index is not None and i == separator_index:
-            print(" ...")
-
-        # Extract appropriate slice
-        vec = t[0, si]
-        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
-            flat = vec.flatten().tolist()
-        else:  # 2D or 3D case
-            flat = vec.tolist()
-
-        # First and last slices
-        first = flat[:max_vals]
-        last = flat[-max_vals:] if len(flat) >= 2 * max_vals else flat
-        first_str = ", ".join(f"{v:12.4f}" for v in first)
-        last_str = ", ".join(f"{v:12.4f}" for v in last)
-
-        if len(flat) >= 2 * max_vals:
-            print(f" [{first_str}, ..., {last_str}]")
-        else:
-            print(f" [{last_str}]")
-
-    print(" ],")
-    print(" ]")
-    print(f" sum = {t.sum().item():.6f}\n")
+    print(f"ggml_debug: {name} = (f32) ... = {{{t.shape}}}\n")
+    print(t)
+    print(f"\nsum = {t.sum().item():.6f}\n")
 
     indexed_patterns = [r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+"]
     non_indexed_patterns = [r"k_pad", r"v_pad", r"q_scaled"]
@@ -146,11 +81,41 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
         layer_counter[name] = layer_counter[name] + 1  # skip attention layers
         save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name]}_{token_counter[name]}.bin")
 
-from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402
+from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm, repeat_kv  # noqa: E402
+from transformers.processing_utils import Unpack  # noqa: E402
+from transformers.utils.generic import TransformersKwargs  # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
 orig_rope = apply_rotary_pos_emb
 import torch.nn.functional as F  # noqa: E402
 import typing  # noqa: E402
+from torch import nn  # noqa: E402
+
+def patched_eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: typing.Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    print(f"\nAttention scaling: {scaling}\n")
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    summarize(attn_output, "attn_output")
+
+    return attn_output, attn_weights
 
 def patched_torch_causal_conv1d_update(
     hidden_states,
@@ -343,7 +308,13 @@ def patched_torch_chunk_gated_delta_rule(
         summarize(k_i, f"k_i_chunk_{i}")
         summarize(v_i, f"v_i_chunk_{i}")
 
-        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+        q_k_trans = q_i @ k_i.transpose(-1, -2)
+        summarize(q_k_trans, f"q_k_trans_{i}")
+
+        q_k_trans_decay = q_k_trans * decay_mask[:, :, i]
+        summarize(q_k_trans_decay, f"q_k_trans_decay_{i}")
+
+        attn = q_k_trans_decay.masked_fill_(mask, 0)
         summarize(attn, f"attn_chunk_{i}")
 
         v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
@@ -362,16 +333,31 @@ def patched_torch_chunk_gated_delta_rule(
         summarize(g_last, f"g_last_chunk_{i}")
 
         g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
-        last_recurrent_state = (
-            last_recurrent_state * g_last
-            + (k_i * g_diff_exp[..., None]).transpose(-1, -2) @ v_new
-        )
+        summarize(g_diff_exp, f"g_diff_exp_chunk_{i}")
+
+        state_g_last = last_recurrent_state * g_last
+        summarize(state_g_last, f"state_g_last_{i}")
+
+        k_g_diffexp = (k_i * g_diff_exp[..., None])
+        summarize(k_g_diffexp, f"k_g_diffexp_{i}")
+
+        k_g_diffexp_T = k_g_diffexp.transpose(-1, -2)
+        summarize(k_g_diffexp_T, f"k_g_diffexp_T_{i}")
+
+        kgd_mul_vnew = k_g_diffexp_T @ v_new
+        summarize(kgd_mul_vnew, f"kgd_mul_vnew_{i}")
+
+        last_recurrent_state = state_g_last + kgd_mul_vnew
         summarize(last_recurrent_state, f"updated_state_chunk_{i}")
 
     if not output_final_state:
         last_recurrent_state = None
     core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
+    summarize(core_attn_out, "attn_out_reshaped")
+
     core_attn_out = core_attn_out[:, :, :num_heads]
+    summarize(core_attn_out, "attn_out_truncated")
+
     core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
     summarize(core_attn_out, "attn_out")
 
@@ -451,6 +437,7 @@ def patched_torch_recurrent_gated_delta_rule(
 qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
 qwen_mod.apply_rotary_pos_emb = patched_apply_rope
 qwen_mod.torch_recurrent_gated_delta_rule = patched_torch_recurrent_gated_delta_rule
+qwen_mod.eager_attention_forward = patched_eager_attention_forward
 
 # Store original functions for patching
 original_functions = {}
@@ -736,6 +723,8 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
 all_generated_tokens = []
 all_logits = []
 
+model.config._attn_implementation = "eager"
+
 with torch.no_grad():
     # Initial forward pass
     print(f"\n=== Initial Forward Pass ===")