@@ -69,7 +69,7 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
         t = t.unsqueeze(0)
     elif t.ndim == 4:
         _, s, _, _ = t.shape
-
+
     else:
         print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
         return
@@ -124,7 +124,6 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
 
     indexed_patterns = [r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+"]
     non_indexed_patterns = [r"k_pad", r"v_pad", r"q_scaled"]
-
     if any(re.fullmatch(p, name) for p in indexed_patterns):
         if name not in token_counter:
             token_counter[name] = 1
@@ -135,13 +134,17 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
     if any(re.fullmatch(p, name) for p in non_indexed_patterns):
         if name not in token_counter:
             token_counter[name] = 1
-        else:
-            token_counter[name] = token_counter[name] + 1
-        if name not in layer_counter or layer_counter[name] == num_model_layers - 1:
+
+        if name not in layer_counter:
+            layer_counter[name] = 0
+        elif layer_counter[name] >= num_model_layers:
             layer_counter[name] = 0
+            token_counter[name] = token_counter[name] + 1
         else:
             layer_counter[name] = layer_counter[name] + 1
-        save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name] - 1}_{token_counter[name]}.bin")
+        if layer_counter[name] % 4 == 3:
+            layer_counter[name] = layer_counter[name] + 1  # skip attention layers
+        save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name]}_{token_counter[name]}.bin")
 
 from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
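For orientation, the new bookkeeping assigns each non-indexed tensor a (layer, token) pair and skips every fourth layer slot. A standalone sketch of the resulting index sequence (num_model_layers here is hypothetical, not taken from the diff):

# Standalone sketch of the counter logic above (hypothetical
# num_model_layers; the % 4 == 3 skip mirrors the diff's attention-layer gap).
layer_counter, token_counter = {}, {}
num_model_layers = 8

def next_index(name):
    if name not in token_counter:
        token_counter[name] = 1
    if name not in layer_counter:
        layer_counter[name] = 0
    elif layer_counter[name] >= num_model_layers:
        layer_counter[name] = 0
        token_counter[name] = token_counter[name] + 1
    else:
        layer_counter[name] = layer_counter[name] + 1
    if layer_counter[name] % 4 == 3:
        layer_counter[name] = layer_counter[name] + 1  # skip attention layers
    return layer_counter[name], token_counter[name]

print([next_index("k_pad") for _ in range(6)])
# -> [(0, 1), (1, 1), (2, 1), (4, 1), (5, 1), (6, 1)]   (layer 3 skipped)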
@@ -181,20 +184,20 @@ def save_tensor(tensor, filename):
     """Save tensor to binary file with shape information."""
     # Ensure tensors directory exists
     os.makedirs(os.path.dirname(filename), exist_ok=True)
-
+
     # Convert to numpy and save
     np_array = tensor.detach().cpu().numpy()
-
+
     # Save shape first (4 int64 values), then data
     with open(filename, 'wb') as f:
         shape = list(np_array.shape)
         while len(shape) < 4:
             shape.insert(0, 0)
-
+
         # Write shape as int64
         shape_array = np.array(shape, dtype=np.int64)
         f.write(shape_array.tobytes())
-
+
         # Write data as float32
         np_array_float32 = np_array.astype(np.float32)
         f.write(np_array_float32.tobytes())
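The format is easy to read back: four int64 dims, left-padded with zeros, followed by the float32 payload. A minimal reader sketch (assumes no genuine dimension is ever 0, since 0 marks the padding):

# Minimal reader for the save_tensor format above (assumes no genuine
# zero-sized dimension, since 0 is used as left padding in the header).
import numpy as np

def load_tensor(filename):
    with open(filename, 'rb') as f:
        header = np.frombuffer(f.read(4 * 8), dtype=np.int64)
        shape = tuple(int(d) for d in header if d != 0)
        data = np.frombuffer(f.read(), dtype=np.float32)
    return data.reshape(shape)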
@@ -311,19 +314,19 @@ def patched_torch_chunk_gated_delta_rule(
         row = attn[..., i, :i].clone()
         sub = attn[..., :i, :i].clone()
         attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-        #if i <= num_heads and not long:
+        #if i <= num_heads and not long:
             #print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1)*sub}\n")
             #print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
     summarize(attn, "attn_chunks")
     attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
     summarize(attn, "attn_eye")
-
+
     value = attn @ v_beta
     summarize(value, "value")
-
+
     k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
     summarize(k_cumdecay, "k_cumdecay")
-
+
     last_recurrent_state = (
         torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
         if initial_state is None
@@ -339,25 +342,25 @@ def patched_torch_chunk_gated_delta_rule(
         summarize(q_i, f"q_i_chunk_{i}")
         summarize(k_i, f"k_i_chunk_{i}")
         summarize(v_i, f"v_i_chunk_{i}")
-
+
         attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
         summarize(attn, f"attn_chunk_{i}")
-
+
         v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
         summarize(v_prime, f"v_prime_chunk_{i}")
-
+
         v_new = v_i - v_prime
         summarize(v_new, f"v_new_chunk_{i}")
-
+
         attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
         summarize(attn_inter, f"attn_inter_chunk_{i}")
-
+
         core_attn_out[:, :, i] = attn_inter + attn @ v_new
         summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
-
+
         g_last = g[:, :, i, -1, None, None].exp()
         summarize(g_last, f"g_last_chunk_{i}")
-
+
         g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
         last_recurrent_state = (
             last_recurrent_state * g_last
@@ -371,7 +374,7 @@ def patched_torch_chunk_gated_delta_rule(
     core_attn_out = core_attn_out[:, :, :num_heads]
     core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
     summarize(core_attn_out, "attn_out")
-
+
     if isinstance(last_recurrent_state, torch.Tensor):
         summarize(last_recurrent_state, "state_out")
 
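Taken together, the instrumented loop is the usual chunked gated delta rule: subtract what the carried state already explains, combine intra- and inter-chunk contributions, then decay and update the state. A toy single-head restatement of that data flow (hypothetical sizes; masking and broadcasting simplified, not the exact HF code):

# Toy single-head restatement of the per-chunk recurrence instrumented above
# (hypothetical sizes; masking/broadcasting simplified, not the exact HF code).
import torch

C, Dk, Dv = 64, 128, 128                     # chunk length, key dim, value dim
S = torch.zeros(Dk, Dv)                      # last_recurrent_state
q_i, k_i = torch.randn(C, Dk), torch.randn(C, Dk)
v_i = torch.randn(C, Dv)
g = torch.randn(C).cumsum(0)                 # cumulative log-decay in the chunk

decay = (g[:, None] - g[None, :]).tril().exp().tril()   # decay_mask analogue
attn = (q_i @ k_i.T * decay).tril()                     # causal intra-chunk scores
v_prime = (k_i * g.exp()[:, None]) @ S                  # what the state predicts
v_new = v_i - v_prime                                   # delta-rule correction
out_i = (q_i * g.exp()[:, None]) @ S + attn @ v_new     # inter- + intra-chunk
S = S * g[-1].exp() + (k_i * (g[-1] - g).exp()[:, None]).T @ v_new  # carry state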
@@ -615,28 +618,28 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
     """Save KV cache tensors for each layer"""
     cache_dir = data_dir / f"kv_cache_step_{step_num}"
     cache_dir.mkdir(exist_ok=True)
-
+
     # Access past_key_values if available
     if past_key_values is not None:
         for layer_idx, cache_tuple in enumerate(past_key_values):
             if cache_tuple is None:
                 print(f"Cache tuple is None for layer {layer_idx} at step {step_num}")
                 continue
-
+
             # Handle different cache formats
             if isinstance(cache_tuple, (tuple, list)) and len(cache_tuple) >= 2:
                 key, value = cache_tuple[0], cache_tuple[1]
-
+
                 # Check if key and value are not None
                 if key is not None and value is not None:
                     # Save key cache
                     key_filename = cache_dir / f"layer_{layer_idx}_key.bin"
                     key.detach().cpu().numpy().astype(np.float32).tofile(key_filename)
-
+
                     # Save value cache
                     value_filename = cache_dir / f"layer_{layer_idx}_value.bin"
                     value.detach().cpu().numpy().astype(np.float32).tofile(value_filename)
-
+
                     print(f"Saved KV cache for layer {layer_idx} at step {step_num}: key.shape={key.shape}, value.shape={value.shape}")
                 else:
                     print(f"Key or value is None for layer {layer_idx} at step {step_num}")
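Unlike save_tensor, these dumps are headerless: ndarray.tofile writes raw bytes, so the shape printed in the log line is the only record. A matching loader sketch (the shape shown is hypothetical):

# Loader sketch for the headerless KV dumps; tofile() keeps no shape, so take
# it from the "Saved KV cache ..." log line (the shape below is hypothetical).
import numpy as np

def load_kv(filename, shape):
    return np.fromfile(filename, dtype=np.float32).reshape(shape)

key = load_kv("kv_cache_step_0/layer_0_key.bin", (1, 8, 17, 128))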
@@ -738,67 +741,67 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
 print(f"\n=== Initial Forward Pass ===")
 outputs = model(input_ids, use_cache=True)
 logits = outputs.logits
-
+
 # Extract logits for the last token (next token prediction)
 last_logits = logits[0, -1, :].cpu().numpy()
 all_logits.append(last_logits)
-
+
 print(f"Logits shape: {logits.shape}")
 print(f"Last token logits shape: {last_logits.shape}")
-
+
 # Generate first token
 next_token_id = np.argmax(last_logits).item()
 all_generated_tokens.append(next_token_id)
-
+
 # Show top 5 predicted tokens for first step
 top_indices = np.argsort(last_logits)[-5:][::-1]
 print("Top 5 predictions for first token:")
 for idx in top_indices:
     token = tokenizer.decode([idx])
     print(f"  Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
-
+
 print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")
-
+
 # Save KV cache if requested
 if args.save_cache:
     save_kv_cache(outputs.past_key_values, 0, data_dir, model_name)
-
+
 # Prepare for next iteration
 past_key_values = outputs.past_key_values
 current_input = torch.tensor([[next_token_id]], device=device)
-
+
 # Generate remaining tokens
 for step in range(1, args.num_tokens):
     print(f"\n=== Generation Step {step} ===")
-
+
     # Forward pass with cache
     outputs = model(
-        input_ids=current_input,
+        input_ids=current_input,
         past_key_values=past_key_values,
         use_cache=True
     )
-
+
     logits = outputs.logits
     last_logits = logits[0, -1, :].cpu().numpy()
     all_logits.append(last_logits)
-
+
     # Generate next token
     next_token_id = np.argmax(last_logits).item()
     all_generated_tokens.append(next_token_id)
-
+
     # Show top 5 predicted tokens for this step
     top_indices = np.argsort(last_logits)[-5:][::-1]
     print(f"Top 5 predictions for step {step}:")
     for idx in top_indices:
         token = tokenizer.decode([idx])
         print(f"  Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
-
+
     print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")
-
+
     # Save KV cache if requested
     if args.save_cache:
         save_kv_cache(outputs.past_key_values, step, data_dir, model_name)
-
+
     # Update for next iteration
     past_key_values = outputs.past_key_values
     current_input = torch.tensor([[next_token_id]], device=device)
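Since the loop is plain argmax decoding, it can be cross-checked against HF's built-in greedy path; a sanity-check sketch reusing the objects above (assumes no generation_config defaults such as repetition penalty interfere):

# Sanity check (sketch): greedy generate() should reproduce the manual loop,
# assuming no generation_config defaults (e.g. repetition penalty) interfere.
ref = model.generate(input_ids, max_new_tokens=args.num_tokens, do_sample=False)
assert ref[0, input_ids.shape[1]:].tolist() == all_generated_tokens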
@@ -816,7 +819,7 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
     f.write(f"Generated tokens: {all_generated_tokens}\n")
     f.write(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}\n")
     f.write(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}\n\n")
-
+
     for step, logits in enumerate(all_logits):
         f.write(f"=== Step {step} logits ===\n")
         for i, logit in enumerate(logits):
@@ -832,4 +835,4 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
 print(f"Saved txt logits to: {txt_filename}")
 
 if args.save_cache:
-    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")
+    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")