@@ -118,14 +118,15 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
     print(f" sum = {t.sum().item():.6f}\n")

     pattern = r"model\.layers\.[0-9]+_out"
-    if re.fullmatch(pattern, name):
+    pattern2 = r"recurrent_cache_[0-9]+"
+    if re.fullmatch(pattern, name) or re.fullmatch(pattern2, name):
         if name not in token_counter:
             token_counter[name] = 1
         else:
             token_counter[name] = token_counter[name] + 1
         save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")

-from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb  # noqa: E402
+from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
 orig_rope = apply_rotary_pos_emb
 import torch.nn.functional as F  # noqa: E402
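(Aside, not part of the patch: the two regexes added above decide which tensor names get written under reference/tensors/org/. A quick standalone sketch of what they accept, using illustrative names in the script's own naming style:)

    import re

    pattern = r"model\.layers\.[0-9]+_out"
    pattern2 = r"recurrent_cache_[0-9]+"

    for name in ("model.layers.3_out", "recurrent_cache_7", "model.layers.3.linear_attn.forward.out"):
        hit = re.fullmatch(pattern, name) or re.fullmatch(pattern2, name)
        print(f"{name}: {'dumped' if hit else 'skipped'}")
    # model.layers.3_out: dumped
    # recurrent_cache_7: dumped
    # model.layers.3.linear_attn.forward.out: skipped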
@@ -189,17 +190,17 @@ def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     summarize(k, "RoPE.k_in")
     summarize(cos, "cos")
     summarize(sin, "sin")
-    if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
-        already_dumped_rope = True
-        print("Dumping input tensors")
-        save_tensor(q, "reference/tensors/testrope_q_in.bin")
-        save_tensor(k, "reference/tensors/testrope_k_in.bin")
-        save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
-        save_tensor(sin, "reference/tensors/testrope_sin_in.bin")
+    # if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
+    # already_dumped_rope = True
+    # print("Dumping input tensors")
+    # save_tensor(q, "reference/tensors/testrope_q_in.bin")
+    # save_tensor(k, "reference/tensors/testrope_k_in.bin")
+    # save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
+    # save_tensor(sin, "reference/tensors/testrope_sin_in.bin")

     if position_ids:
         summarize(position_ids, "position_ids")
-    print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")
+    # print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")

     # call original
     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
@@ -210,9 +211,231 @@ def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):

     return q_out, k_out

+def patched_torch_chunk_gated_delta_rule(
+    query,
+    key,
+    value,
+    g,
+    beta,
+    chunk_size=64,
+    initial_state=None,
+    output_final_state=False,
+    use_qk_l2norm_in_kernel=False,
+    long=False
+):
+    torch.set_printoptions(threshold=10_000_000, sci_mode=False, precision=10, linewidth=200)
+    initial_dtype = query.dtype
+    [summarize(x, y) for (x, y) in ((query, "q_prenorm"), (key, "k_prenorm"))]
+    if use_qk_l2norm_in_kernel:
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
+    [summarize(x, y) for (x, y) in ((query, "q_orig"), (key, "k_orig"), (value, "v_orig"), (beta, "b_orig"), (g, "g_orig"))]
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
+    ]
+    [summarize(x, y) for (x, y) in ((query, "q_tra"), (key, "k_tra"), (value, "v_tra"), (beta, "b_tra"), (g, "g_tra"))]
+    batch_size, sequence_length, num_heads, k_head_dim = key.shape
+    print(f"batch_size = {batch_size}, seq_len = {sequence_length}, num_heads = {num_heads}, k_head_dim = {k_head_dim}")
+    v_head_dim = value.shape[-1]
+    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
+    print(f"Pad size = {pad_size}, chunk_size = {chunk_size}")
+    query = F.pad(query, (0, 0, 0, pad_size))
+    key = F.pad(key, (0, 0, 0, pad_size))
+    value = F.pad(value, (0, 0, 0, pad_size))
+    beta = F.pad(beta, (0, pad_size))
+    g = F.pad(g, (0, pad_size))
+    [summarize(x, y) for (x, y) in ((query, "q_pad"), (key, "k_pad"), (value, "v_pad"), (beta, "b_pad"), (g, "g_pad"))]
+    tot_heads = num_heads + pad_size
+    scale = 1 / (query.shape[-1] ** 0.5)
+    print(f"Scale for delta is {scale} (from {query.shape[-1]})")
+    query = query * scale
+
+    summarize(query, "q_scaled")
+    summarize(key, "k")
+    summarize(beta.unsqueeze(-1), "beta")
+    v_beta = value * beta.unsqueeze(-1)
+    k_beta = key * beta.unsqueeze(-1)
+    summarize(k_beta, "k_beta")
+    summarize(v_beta, "v_beta")
+    # reshape to chunks
+    query, key, value, k_beta, v_beta = [
+        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
+    ]
+    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
+    [summarize(x, y) for (x, y) in ((query, "q_resh"), (k_beta, "k_beta_resh"), (v_beta, "v_beta_resh"), (key, "k_resh"), (value, "v_resh"))]
+
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
+
+    # chunk decay
+    g = g.cumsum(dim=-1)
+    summarize(g, "g_cumsum")
+    sub = g.unsqueeze(-1) - g.unsqueeze(-2)
+    bt1, bt2 = torch.broadcast_tensors(g.unsqueeze(-1), g.unsqueeze(-2))
+    summarize(bt1, "bt1")
+    summarize(bt2, "bt2")
+    summarize(sub, "sub")
+    decay_mask = sub.tril()
+    summarize(decay_mask, "sub_tril")
+    decay_mask = decay_mask.exp()
+    summarize(decay_mask, "sub_tril_exp")
+    decay_mask = decay_mask.float()
+    summarize(decay_mask, "sub_tril_exp_float")
+    decay_mask = decay_mask.tril()
+    summarize(decay_mask, "decay_mask")
+    k_t = key.transpose(-1, -2)
+    summarize(k_t, "k_t")
+    kmul = k_beta @ k_t
+    summarize(kmul, "k_beta @ k_t")
+    #if not long:
+    #print(f"k_beta @ k_t:\n{kmul[:,:,:,:8,:8]}\n\n")
+    kmul_decay = kmul * decay_mask
+    summarize(kmul_decay, "(k_beta @ k_t) * decay_mask")
+    attn = -(kmul_decay).masked_fill(mask, 0)
+    summarize(attn, "attn_in")
+    for i in range(1, chunk_size):
+        row = attn[..., i, :i].clone()
+        sub = attn[..., :i, :i].clone()
+        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+        #if i <= num_heads and not long:
+        #print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1)*sub}\n")
+        #print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
+    summarize(attn, "attn_chunks")
+    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+    summarize(attn, "attn_eye")
+
+    value = attn @ v_beta
+    summarize(value, "value")
+
+    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+    summarize(k_cumdecay, "k_cumdecay")
+
+    last_recurrent_state = (
+        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
+        if initial_state is None
+        else initial_state.to(value)
+    )
+    core_attn_out = torch.zeros_like(value)
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
+
+    # for each chunk
+    for i in range(0, tot_heads // chunk_size):
+        print(f"\n=== Processing chunk {i} ===")
+        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
+        summarize(q_i, f"q_i_chunk_{i}")
+        summarize(k_i, f"k_i_chunk_{i}")
+        summarize(v_i, f"v_i_chunk_{i}")
+
+        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+        summarize(attn, f"attn_chunk_{i}")
+
+        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
+        summarize(v_prime, f"v_prime_chunk_{i}")
+
+        v_new = v_i - v_prime
+        summarize(v_new, f"v_new_chunk_{i}")
+
+        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+        summarize(attn_inter, f"attn_inter_chunk_{i}")
+
+        core_attn_out[:, :, i] = attn_inter + attn @ v_new
+        summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
+
+        g_last = g[:, :, i, -1, None, None].exp()
+        summarize(g_last, f"g_last_chunk_{i}")
+
+        g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
+        last_recurrent_state = (
+            last_recurrent_state * g_last
+            + (k_i * g_diff_exp[..., None]).transpose(-1, -2) @ v_new
+        )
+        summarize(last_recurrent_state, f"updated_state_chunk_{i}")
+
+    if not output_final_state:
+        last_recurrent_state = None
+    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
+    core_attn_out = core_attn_out[:, :, :num_heads]
+    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+    summarize(core_attn_out, "attn_out")
+    if not long:
+        print(f"attn_out:\n{core_attn_out}\n\n")
+
+    if isinstance(last_recurrent_state, torch.Tensor):
+        summarize(last_recurrent_state, "state_out")
+        if not long:
+            print(f"state_out:\n{last_recurrent_state}\n\n")
+    return core_attn_out, last_recurrent_state
+
+
+def patched_torch_recurrent_gated_delta_rule(
+    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
+):
+    initial_dtype = query.dtype
+    if use_qk_l2norm_in_kernel:
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
+    ]
+    summarize(query, "q_t")
+    summarize(key, "k_t")
+    summarize(value, "v_t")
+    summarize(beta, "beta_t")
+    summarize(g, "g_t")
+
+    batch_size, num_heads, sequence_length, k_head_dim = key.shape
+    v_head_dim = value.shape[-1]
+    scale = 1 / (query.shape[-1] ** 0.5)
+    query = query * scale
+
+    summarize(query, "q_scaled")
+    if initial_state is not None:
+        summarize(initial_state, "initial_state")
+
+    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
+    last_recurrent_state = (
+        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
+        if initial_state is None
+        else initial_state.to(value)
+    )
+
+    for i in range(sequence_length):
+        q_t = query[:, :, i]
+        k_t = key[:, :, i]
+        v_t = value[:, :, i]
+        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
+        summarize(g_t, "g_exp_unsq")
+        beta_t = beta[:, :, i].unsqueeze(-1)
+        summarize(beta_t, "beta_t_unsq")
+
+        last_recurrent_state = last_recurrent_state * g_t
+        summarize(last_recurrent_state, "gated_state")
+        k_unsq = k_t.unsqueeze(-1)
+        summarize(k_unsq, "k_unsqueeze")
+        state_k = last_recurrent_state * k_unsq
+        summarize(state_k, "state_k_product")
+        kv_mem = state_k.sum(dim=-2)
+        summarize(kv_mem, "kv_mem")
+        delta = (v_t - kv_mem) * beta_t
+        summarize(delta, "delta")
+        k_delta = k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+        summarize(k_delta, "k_delta")
+        last_recurrent_state = last_recurrent_state + k_delta
+        summarize(last_recurrent_state, "state_plus_k_delta")
+        state_q_prod = last_recurrent_state * q_t.unsqueeze(-1)
+        summarize(state_q_prod, "state_q_product")
+        core_attn_out[:, :, i] = state_q_prod.sum(dim=-2)
+        summarize(core_attn_out, "core_attn_out")
+
+    if not output_final_state:
+        last_recurrent_state = None
+    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+    return core_attn_out, last_recurrent_state
+
 import transformers.models.qwen3_next.modeling_qwen3_next as qwen_mod  # noqa: E402
+qwen_mod.torch_chunk_gated_delta_rule = patched_torch_chunk_gated_delta_rule
 qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
 qwen_mod.apply_rotary_pos_emb = patched_apply_rope
+qwen_mod.torch_recurrent_gated_delta_rule = patched_torch_recurrent_gated_delta_rule

 # Store original functions for patching
 original_functions = {}
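(Aside, not part of the patch: stripped of the summarize calls, the per-token update that patched_torch_recurrent_gated_delta_rule instruments is a small gated delta-rule recurrence. A minimal standalone sketch of one timestep, with toy shapes chosen here purely for illustration; it assumes q_t is already scaled by 1/sqrt(d_k) and q/k already L2-normalized, as in the patched function:)

    import torch

    def delta_rule_step(state, q_t, k_t, v_t, g_t, beta_t):
        # state: (batch, heads, d_k, d_v); q_t, k_t: (batch, heads, d_k)
        # v_t: (batch, heads, d_v); g_t, beta_t: (batch, heads)
        state = state * g_t.exp()[..., None, None]                # gated decay of the memory
        kv_mem = (state * k_t[..., :, None]).sum(dim=-2)          # what the memory currently returns for k_t
        delta = (v_t - kv_mem) * beta_t[..., None]                # delta-rule error, scaled by beta
        state = state + k_t[..., :, None] * delta[..., None, :]   # rank-1 state correction
        out = (state * q_t[..., :, None]).sum(dim=-2)             # read out with the query
        return out, state

    b, h, d_k, d_v = 1, 2, 4, 4                                   # toy sizes, not the model's
    state = torch.zeros(b, h, d_k, d_v)
    out, state = delta_rule_step(state, torch.randn(b, h, d_k), torch.randn(b, h, d_k),
                                 torch.randn(b, h, d_v), torch.randn(b, h), torch.rand(b, h))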
@@ -259,6 +482,18 @@ def patched_forward(*args, **kwargs):
         # Call original forward
         result = orig_forward(*args, **kwargs)

+        if mod_name.endswith("linear_attn"):
+            cache = kwargs["cache_params"]
+            nameparts = mod_name.split(".")
+            layer_idx = -1
+            try:
+                layer_idx = int(nameparts[2])
+            except (ValueError, IndexError):
+                print(f"\n\nDEBUG: Failed to calculate layer index for module: {mod_name}\n\n")
+            rec_cache = cache.recurrent_states[layer_idx]
+            if rec_cache is not None:
+                summarize(rec_cache, f"recurrent_cache_{layer_idx}")
+
         # Log output
         if isinstance(result, torch.Tensor):
             summarize(result, f"{mod_name}.forward.out")
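(Aside, not part of the patch: the added cache-dumping block assumes hooked-module names of the form model.layers.<idx>.linear_attn, so splitting on "." puts the layer index at position 2; a quick illustration under that assumption, with a made-up name:)

    mod_name = "model.layers.12.linear_attn"   # hypothetical example name
    layer_idx = int(mod_name.split(".")[2])    # -> 12, the index used into cache.recurrent_states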