 layer_counter = {}
 num_model_layers = 0
 
-def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
+
+def summarize(tensor: torch.Tensor, name: str, max_seq: int = 4):
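+    # Let torch's own printing handle truncation: threshold/edgeitems are chosen so large
+    # tensors show only their leading and trailing values, similar to llama.cpp debug dumps.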
48+     torch .set_printoptions (precision  =  6 , edgeitems  =  max_seq , linewidth  =  160 , sci_mode  =  False , threshold  =  50 )
4749    global  num_model_layers , layer_counter , token_counter 
4850    """ 
4951    Print a tensor in llama.cpp debug style. 
50- 
51-     Supports: 
52-     - 2D tensors (seq, hidden) 
53-     - 3D tensors (batch, seq, hidden) 
54-     - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head 
55-     - 5D tensors 
56- 
-    Shows first and last max_vals of each vector per sequence position.
+    Shows the leading and trailing values of each dimension via torch's print options.
     """
     t = tensor.detach().to(torch.float32).cpu()
-    ten_shape = t.shape
-    while t.ndim > 4:
-        t = t.squeeze(0)
-
-    # Determine dimensions
-    if t.ndim == 3:
-        _, s, _ = t.shape
-    elif t.ndim == 2:
-        _, s = 1, t.shape[0]
-        t = t.unsqueeze(0)
-    elif t.ndim == 4:
-        _, s, _, _ = t.shape
-
-    else:
-        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
-        return
-
-    print(f"ggml_debug: {name} = (f32)  ... = {{{ten_shape}}}")
-    print("                                     [")
-    print("                                      [")
-
-    # Determine indices for first and last sequences
-    first_indices = list(range(min(s, max_seq)))
-    last_indices = list(range(max(0, s - max_seq), s))
 
-    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
-    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
-
-    # Combine indices
-    if has_overlap:
-        # If there's overlap, just use the combined unique indices
-        indices = sorted(list(set(first_indices + last_indices)))
-        separator_index = None
-    else:
-        # If no overlap, we'll add a separator between first and last sequences
-        indices = first_indices + last_indices
-        separator_index = len(first_indices)
-
-    for i, si in enumerate(indices):
-        # Add separator if needed
-        if separator_index is not None and i == separator_index:
-            print("                                       ...")
-
-        # Extract appropriate slice
-        vec = t[0, si]
-        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
-            flat = vec.flatten().tolist()
-        else:  # 2D or 3D case
-            flat = vec.tolist()
-
-        # First and last slices
-        first = flat[:max_vals]
-        last = flat[-max_vals:] if len(flat) >= 2 * max_vals else flat
-        first_str = ", ".join(f"{v:12.4f}" for v in first)
-        last_str = ", ".join(f"{v:12.4f}" for v in last)
-
-        if len(flat) >= 2 * max_vals:
-            print(f"                                       [{first_str}, ..., {last_str}]")
-        else:
-            print(f"                                       [{last_str}]")
-
-    print("                                      ],")
-    print("                                     ]")
-    print(f"                                     sum = {t.sum().item():.6f}\n")
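+    # Emit a ggml_debug-style header, the tensor itself (truncated by the print options
+    # above), and the tensor sum so the output lines up with the llama.cpp dumps.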
+    print(f"ggml_debug: {name} = (f32)  ... = {{{t.shape}}}\n")
+    print(t)
+    print(f"\n                                      sum = {t.sum().item():.6f}\n")
 
     indexed_patterns = [ r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+" ]
     non_indexed_patterns = [ r"k_pad", r"v_pad", r"q_scaled" ]
@@ -146,11 +81,41 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
                 layer_counter[name] = layer_counter[name] + 1  # skip attention layers
         save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name]}_{token_counter[name]}.bin")
 
-from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402
+from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm, repeat_kv  # noqa: E402
+from transformers.processing_utils import Unpack  # noqa: E402
+from transformers.utils.generic import TransformersKwargs  # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
 orig_rope = apply_rotary_pos_emb
 import torch.nn.functional as F  # noqa: E402
 import typing  # noqa: E402
+from torch import nn  # noqa: E402
+
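+# Mirrors transformers' stock eager attention forward, with an extra summarize() call on
+# attn_output so the full-attention layers are dumped alongside the linear-attention path.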
+def patched_eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: typing.Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    print(f"\nAttention scaling: {scaling}\n")
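+    # repeat_kv expands the KV heads across the key/value groups so they match the number of query heads (GQA).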
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    summarize(attn_output, "attn_output")
+
+    return attn_output, attn_weights
 
 def patched_torch_causal_conv1d_update(
     hidden_states,
@@ -343,7 +308,13 @@ def patched_torch_chunk_gated_delta_rule(
         summarize(k_i, f"k_i_chunk_{i}")
         summarize(v_i, f"v_i_chunk_{i}")
 
-        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
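+        # Break the fused attention expression into named intermediates so each step
+        # can be summarized and compared against the corresponding llama.cpp tensors.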
+        q_k_trans = q_i @ k_i.transpose(-1, -2)
+        summarize(q_k_trans, f"q_k_trans_{i}")
+
+        q_k_trans_decay = q_k_trans * decay_mask[:, :, i]
+        summarize(q_k_trans_decay, f"q_k_trans_decay_{i}")
+
+        attn = q_k_trans_decay.masked_fill_(mask, 0)
         summarize(attn, f"attn_chunk_{i}")
 
         v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
@@ -362,16 +333,31 @@ def patched_torch_chunk_gated_delta_rule(
         summarize(g_last, f"g_last_chunk_{i}")
 
         g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
-        last_recurrent_state = (
-            last_recurrent_state * g_last
-            + (k_i * g_diff_exp[..., None]).transpose(-1, -2) @ v_new
-        )
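+        # Decompose the recurrent-state update (state * g_last + (k_i * g_diff_exp)^T @ v_new)
+        # into named intermediates so every term can be dumped individually.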
+        summarize(g_diff_exp, f"g_diff_exp_chunk_{i}")
+
+        state_g_last = last_recurrent_state * g_last
+        summarize(state_g_last, f"state_g_last_{i}")
+
+        k_g_diffexp = (k_i * g_diff_exp[..., None])
+        summarize(k_g_diffexp, f"k_g_diffexp_{i}")
+
+        k_g_diffexp_T = k_g_diffexp.transpose(-1, -2)
+        summarize(k_g_diffexp_T, f"k_g_diffexp_T_{i}")
+
+        kgd_mul_vnew = k_g_diffexp_T @ v_new
+        summarize(kgd_mul_vnew, f"kgd_mul_vnew_{i}")
+
+        last_recurrent_state = state_g_last + kgd_mul_vnew
         summarize(last_recurrent_state, f"updated_state_chunk_{i}")
 
     if not output_final_state:
         last_recurrent_state = None
     core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
+    summarize(core_attn_out, "attn_out_reshaped")
+
     core_attn_out = core_attn_out[:, :, :num_heads]
+    summarize(core_attn_out, "attn_out_truncated")
+
     core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
     summarize(core_attn_out, "attn_out")
 
@@ -451,6 +437,7 @@ def patched_torch_recurrent_gated_delta_rule(
 qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
 qwen_mod.apply_rotary_pos_emb = patched_apply_rope
 qwen_mod.torch_recurrent_gated_delta_rule = patched_torch_recurrent_gated_delta_rule
+qwen_mod.eager_attention_forward = patched_eager_attention_forward
 
 # Store original functions for patching
 original_functions = {}
@@ -736,6 +723,8 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
 all_generated_tokens = []
 all_logits = []
 
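+# Force the eager attention implementation so the patched eager_attention_forward above
+# is the code path that actually runs (rather than SDPA / flash attention).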
+model.config._attn_implementation = "eager"
+
 with torch.no_grad():
     # Initial forward pass
     print(f"\n=== Initial Forward Pass ===")