
Commit 10032af

More debug data

1 parent d300ce9 commit 10032af

File tree

1 file changed: +245 -10 lines changed


examples/model-conversion/scripts/causal/run-org-model-multi-token.py

Lines changed: 245 additions & 10 deletions
@@ -118,14 +118,15 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
     print(f" sum = {t.sum().item():.6f}\n")
 
     pattern = r"model\.layers\.[0-9]+_out"
-    if re.fullmatch(pattern, name):
+    pattern2 = r"recurrent_cache_[0-9]+"
+    if re.fullmatch(pattern, name) or re.fullmatch(pattern2, name):
         if name not in token_counter:
             token_counter[name] = 1
         else:
             token_counter[name] = token_counter[name] + 1
         save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")
 
-from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb  # noqa: E402
+from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm  # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
 orig_rope = apply_rotary_pos_emb
 import torch.nn.functional as F  # noqa: E402
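
Note (not part of the commit): with the widened regex above, recurrent_cache_N tensors are now dumped under reference/tensors/org/ alongside the per-layer outputs. A minimal sketch of reading such a dump back for comparison, assuming save_tensor writes a flat array of raw float32 values (its definition is elsewhere in this script and not shown in this diff); the second path is purely hypothetical:

import numpy as np

def load_dump(path, shape=None):
    # Assumption: save_tensor stores raw little-endian float32 values with no header.
    data = np.fromfile(path, dtype=np.float32)
    return data.reshape(shape) if shape is not None else data

# Compare the reference dump against another implementation's dump (hypothetical path).
ref = load_dump("reference/tensors/org/recurrent_cache_0_1.bin")
test = load_dump("reference/tensors/other/recurrent_cache_0_1.bin")
print("max abs diff:", float(np.max(np.abs(ref - test))))
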
@@ -189,17 +190,17 @@ def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     summarize(k, "RoPE.k_in")
     summarize(cos, "cos")
     summarize(sin, "sin")
-    if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
-        already_dumped_rope = True
-        print("Dumping input tensors")
-        save_tensor(q, "reference/tensors/testrope_q_in.bin")
-        save_tensor(k, "reference/tensors/testrope_k_in.bin")
-        save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
-        save_tensor(sin, "reference/tensors/testrope_sin_in.bin")
+    # if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
+    #     already_dumped_rope = True
+    #     print("Dumping input tensors")
+    #     save_tensor(q, "reference/tensors/testrope_q_in.bin")
+    #     save_tensor(k, "reference/tensors/testrope_k_in.bin")
+    #     save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
+    #     save_tensor(sin, "reference/tensors/testrope_sin_in.bin")
 
     if position_ids:
         summarize(position_ids, "position_ids")
-    print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")
+    # print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")
 
     # call original
     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
@@ -210,9 +211,231 @@ def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 
     return q_out, k_out
 
+def patched_torch_chunk_gated_delta_rule(
+    query,
+    key,
+    value,
+    g,
+    beta,
+    chunk_size=64,
+    initial_state=None,
+    output_final_state=False,
+    use_qk_l2norm_in_kernel=False,
+    long=False
+):
+    torch.set_printoptions(threshold=10_000_000, sci_mode=False, precision=10, linewidth=200)
+    initial_dtype = query.dtype
+    [ summarize(x, y) for (x, y) in ((query, "q_prenorm"), (key, "k_prenorm")) ]
+    if use_qk_l2norm_in_kernel:
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
+    [ summarize(x, y) for (x, y) in ((query, "q_orig"), (key, "k_orig"), (value, "v_orig"), (beta, "b_orig"), (g, "g_orig")) ]
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
+    ]
+    [ summarize(x, y) for (x, y) in ((query, "q_tra"), (key, "k_tra"), (value, "v_tra"), (beta, "b_tra"), (g, "g_tra")) ]
+    batch_size, sequence_length, num_heads, k_head_dim = key.shape
+    print(f"batch_size = {batch_size}, seq_len = {sequence_length}, num_heads = {num_heads}, k_head_dim = {k_head_dim}")
+    v_head_dim = value.shape[-1]
+    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
+    print(f"Pad size = {pad_size}, chunk_size = {chunk_size}")
+    query = F.pad(query, (0, 0, 0, pad_size))
+    key = F.pad(key, (0, 0, 0, pad_size))
+    value = F.pad(value, (0, 0, 0, pad_size))
+    beta = F.pad(beta, (0, pad_size))
+    g = F.pad(g, (0, pad_size))
+    [ summarize(x, y) for (x, y) in ((query, "q_pad"), (key, "k_pad"), (value, "v_pad"), (beta, "b_pad"), (g, "g_pad")) ]
+    tot_heads = num_heads + pad_size
+    scale = 1 / (query.shape[-1] ** 0.5)
+    print(f"Scale for delta is {scale} (from {query.shape[-1]})")
+    query = query * scale
+
+    summarize(query, "q_scaled")
+    summarize(key, "k")
+    summarize(beta.unsqueeze(-1), "beta")
+    v_beta = value * beta.unsqueeze(-1)
+    k_beta = key * beta.unsqueeze(-1)
+    summarize(k_beta, "k_beta")
+    summarize(v_beta, "v_beta")
+    # reshape to chunks
+    query, key, value, k_beta, v_beta = [
+        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
+    ]
+    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
+    [ summarize(x, y) for (x, y) in ((query, "q_resh"), (k_beta, "k_beta_resh"), (v_beta, "v_beta_resh"), (key, "k_resh"), (value, "v_resh")) ]
+
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
+
+    # chunk decay
+    g = g.cumsum(dim=-1)
+    summarize(g, "g_cumsum")
+    sub = g.unsqueeze(-1) - g.unsqueeze(-2)
+    bt1, bt2 = torch.broadcast_tensors(g.unsqueeze(-1), g.unsqueeze(-2))
+    summarize(bt1, "bt1")
+    summarize(bt2, "bt2")
+    summarize(sub, "sub")
+    decay_mask = sub.tril()
+    summarize(decay_mask, "sub_tril")
+    decay_mask = decay_mask.exp()
+    summarize(decay_mask, "sub_tril_exp")
+    decay_mask = decay_mask.float()
+    summarize(decay_mask, "sub_tril_exp_float")
+    decay_mask = decay_mask.tril()
+    summarize(decay_mask, "decay_mask")
+    k_t = key.transpose(-1, -2)
+    summarize(k_t, "k_t")
+    kmul = k_beta @ k_t
+    summarize(kmul, "k_beta @ k_t")
+    #if not long:
+    #print(f"k_beta @ k_t:\n{kmul[:,:,:,:8,:8]}\n\n")
+    kmul_decay = kmul * decay_mask
+    summarize(kmul_decay, "(k_beta @ k_t) * decay_mask")
+    attn = -(kmul_decay).masked_fill(mask, 0)
+    summarize(attn, "attn_in")
+    for i in range(1, chunk_size):
+        row = attn[..., i, :i].clone()
+        sub = attn[..., :i, :i].clone()
+        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+        #if i <= num_heads and not long:
+        #print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1)*sub}\n")
+        #print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
+    summarize(attn, "attn_chunks")
+    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+    summarize(attn, "attn_eye")
+
+    value = attn @ v_beta
+    summarize(value, "value")
+
+    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+    summarize(k_cumdecay, "k_cumdecay")
+
+    last_recurrent_state = (
+        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
+        if initial_state is None
+        else initial_state.to(value)
+    )
+    core_attn_out = torch.zeros_like(value)
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
+
+    # for each chunk
+    for i in range(0, tot_heads // chunk_size):
+        print(f"\n=== Processing chunk {i} ===")
+        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
+        summarize(q_i, f"q_i_chunk_{i}")
+        summarize(k_i, f"k_i_chunk_{i}")
+        summarize(v_i, f"v_i_chunk_{i}")
+
+        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+        summarize(attn, f"attn_chunk_{i}")
+
+        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
+        summarize(v_prime, f"v_prime_chunk_{i}")
+
+        v_new = v_i - v_prime
+        summarize(v_new, f"v_new_chunk_{i}")
+
+        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+        summarize(attn_inter, f"attn_inter_chunk_{i}")
+
+        core_attn_out[:, :, i] = attn_inter + attn @ v_new
+        summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
+
+        g_last = g[:, :, i, -1, None, None].exp()
+        summarize(g_last, f"g_last_chunk_{i}")
+
+        g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
+        last_recurrent_state = (
+            last_recurrent_state * g_last
+            + (k_i * g_diff_exp[..., None]).transpose(-1, -2) @ v_new
+        )
+        summarize(last_recurrent_state, f"updated_state_chunk_{i}")
+
+    if not output_final_state:
+        last_recurrent_state = None
+    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
+    core_attn_out = core_attn_out[:, :, :num_heads]
+    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+    summarize(core_attn_out, "attn_out")
+    if not long:
+        print(f"attn_out:\n{core_attn_out}\n\n")
+
+    if isinstance(last_recurrent_state, torch.Tensor):
+        summarize(last_recurrent_state, "state_out")
+        if not long:
+            print(f"state_out:\n{last_recurrent_state}\n\n")
+    return core_attn_out, last_recurrent_state
+
+
+def patched_torch_recurrent_gated_delta_rule(
+    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
+):
+    initial_dtype = query.dtype
+    if use_qk_l2norm_in_kernel:
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
+    ]
+    summarize(query, "q_t")
+    summarize(key, "k_t")
+    summarize(value, "v_t")
+    summarize(beta, "beta_t")
+    summarize(g, "g_t")
+
+    batch_size, num_heads, sequence_length, k_head_dim = key.shape
+    v_head_dim = value.shape[-1]
+    scale = 1 / (query.shape[-1] ** 0.5)
+    query = query * scale
+
+    summarize(query, "q_scaled")
+    if initial_state is not None:
+        summarize(initial_state, "initial_state")
+
+    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
+    last_recurrent_state = (
+        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
+        if initial_state is None
+        else initial_state.to(value)
+    )
+
+    for i in range(sequence_length):
+        q_t = query[:, :, i]
+        k_t = key[:, :, i]
+        v_t = value[:, :, i]
+        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
+        summarize(g_t, "g_exp_unsq")
+        beta_t = beta[:, :, i].unsqueeze(-1)
+        summarize(beta_t, "beta_t_unsq")
+
+        last_recurrent_state = last_recurrent_state * g_t
+        summarize(last_recurrent_state, "gated_state")
+        k_unsq = k_t.unsqueeze(-1)
+        summarize(k_unsq, "k_unsqueeze")
+        state_k = last_recurrent_state * k_unsq
+        summarize(state_k, "state_k_product")
+        kv_mem = state_k.sum(dim=-2)
+        summarize(kv_mem, "kv_mem")
+        delta = (v_t - kv_mem) * beta_t
+        summarize(delta, "delta")
+        k_delta = k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+        summarize(k_delta, "k_delta")
+        last_recurrent_state = last_recurrent_state + k_delta
+        summarize(last_recurrent_state, "state_plus_k_delta")
+        state_q_prod = last_recurrent_state * q_t.unsqueeze(-1)
+        summarize(state_q_prod, "state_q_product")
+        core_attn_out[:, :, i] = state_q_prod.sum(dim=-2)
+        summarize(core_attn_out, "core_attn_out")
+
+    if not output_final_state:
+        last_recurrent_state = None
+    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+    return core_attn_out, last_recurrent_state
+
 import transformers.models.qwen3_next.modeling_qwen3_next as qwen_mod  # noqa: E402
+qwen_mod.torch_chunk_gated_delta_rule = patched_torch_chunk_gated_delta_rule
 qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
 qwen_mod.apply_rotary_pos_emb = patched_apply_rope
+qwen_mod.torch_recurrent_gated_delta_rule = patched_torch_recurrent_gated_delta_rule
 
 # Store original functions for patching
 original_functions = {}
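
Reading aid (not part of the commit): the per-token loop in patched_torch_recurrent_gated_delta_rule above is a gated delta-rule recurrence on a per-head (k_head_dim x v_head_dim) memory. A minimal single-step sketch of that recurrence with the summarize calls stripped out, assuming q_t has already been scaled by 1/sqrt(k_head_dim) as in the function above:

import torch

def gated_delta_rule_step(state, q_t, k_t, v_t, g_t, beta_t):
    # state:       (batch, heads, k_dim, v_dim) running memory
    # q_t, k_t:    (batch, heads, k_dim)        v_t: (batch, heads, v_dim)
    # g_t, beta_t: (batch, heads)               log-decay and write strength
    state = state * g_t.exp()[..., None, None]            # decay the old memory
    kv_mem = (state * k_t[..., None]).sum(dim=-2)         # what the memory currently predicts for v
    delta = (v_t - kv_mem) * beta_t[..., None]            # delta-rule correction
    state = state + k_t[..., None] * delta[..., None, :]  # rank-1 update of the memory
    out = (state * q_t[..., None]).sum(dim=-2)            # read-out with the query
    return out, state
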
@@ -259,6 +482,18 @@ def patched_forward(*args, **kwargs):
         # Call original forward
         result = orig_forward(*args, **kwargs)
 
+        if mod_name.endswith("linear_attn"):
+            cache = kwargs["cache_params"]
+            nameparts = mod_name.split(".")
+            layer_idx = -1
+            try:
+                layer_idx = int(nameparts[2])
+            except (ValueError, IndexError):
+                print(f"\n\nDEBUG: Failed to calculate layer index for module: {mod_name}\n\n")
+            rec_cache = cache.recurrent_states[layer_idx]
+            if rec_cache is not None:
+                summarize(rec_cache, f"recurrent_cache_{layer_idx}")
+
         # Log output
         if isinstance(result, torch.Tensor):
             summarize(result, f"{mod_name}.forward.out")
