
Commit a4fe128

Fix layer counting logic
1 parent 610b0fe commit a4fe128

1 file changed: +48 -45 lines changed

examples/model-conversion/scripts/causal/run-org-model-multi-token.py

Lines changed: 48 additions & 45 deletions
@@ -69,7 +69,7 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
         t = t.unsqueeze(0)
     elif t.ndim == 4:
         _, s, _, _ = t.shape
-
+
     else:
         print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
         return
@@ -124,7 +124,6 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
 
     indexed_patterns = [ r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+" ]
     non_indexed_patterns = [ r"k_pad", r"v_pad", r"q_scaled" ]
-
     if any(re.fullmatch(p, name) for p in indexed_patterns):
         if name not in token_counter:
             token_counter[name] = 1
@@ -135,13 +134,17 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
     if any(re.fullmatch(p, name) for p in non_indexed_patterns):
         if name not in token_counter:
             token_counter[name] = 1
-        else:
-            token_counter[name] = token_counter[name] + 1
-        if name not in layer_counter or layer_counter[name] == num_model_layers - 1:
+
+        if name not in layer_counter:
+            layer_counter[name] = 0
+        elif layer_counter[name] >= num_model_layers:
             layer_counter[name] = 0
+            token_counter[name] = token_counter[name] + 1
         else:
             layer_counter[name] = layer_counter[name] + 1
-        save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name] - 1}_{token_counter[name]}.bin")
+        if layer_counter[name] % 4 == 3:
+            layer_counter[name] = layer_counter[name] + 1 # skip attention layers
+        save_tensor(t, f"reference/tensors/org/{name}_{layer_counter[name]}_{token_counter[name]}.bin")
 
 from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb, l2norm # noqa: E402
 orig_conv1d_update = torch_causal_conv1d_update
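
As a quick illustration of the counting change in the hunk above, here is a minimal standalone sketch of the new bookkeeping. Names mirror the script, but num_model_layers is given a made-up value and the 4-layer block spacing is simply taken from the % 4 == 3 check in the diff, so treat this as an approximation rather than the script itself:

# Minimal sketch of the updated layer/token counter logic (illustration only;
# num_model_layers is a hypothetical value).
num_model_layers = 12
layer_counter = {}
token_counter = {}

def advance(name):
    if name not in token_counter:
        token_counter[name] = 1

    if name not in layer_counter:
        layer_counter[name] = 0
    elif layer_counter[name] >= num_model_layers:
        # wrapped past the last layer: restart at layer 0 for the next token
        layer_counter[name] = 0
        token_counter[name] = token_counter[name] + 1
    else:
        layer_counter[name] = layer_counter[name] + 1
    if layer_counter[name] % 4 == 3:
        layer_counter[name] = layer_counter[name] + 1  # skip attention layers
    return layer_counter[name], token_counter[name]

# Successive calls for one tensor name yield layer indices 0, 1, 2, 4, 5, 6, ...
# index 3 (and every fourth index after it) is skipped.
print([advance("k_pad") for _ in range(6)])

Note that the filename written by save_tensor now uses layer_counter[name] directly instead of the previous layer_counter[name] - 1, so the saved files carry the post-skip layer index.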
@@ -181,20 +184,20 @@ def save_tensor(tensor, filename):
     """Save tensor to binary file with shape information."""
     # Ensure tensors directory exists
     os.makedirs(os.path.dirname(filename), exist_ok=True)
-
+
     # Convert to numpy and save
     np_array = tensor.detach().cpu().numpy()
-
+
     # Save shape first (4 int64 values), then data
     with open(filename, 'wb') as f:
         shape = list(np_array.shape)
         while len(shape) < 4:
             shape.insert(0, 0)
-
+
         # Write shape as int64
         shape_array = np.array(shape, dtype=np.int64)
         f.write(shape_array.tobytes())
-
+
         # Write data as float32
         np_array_float32 = np_array.astype(np.float32)
         f.write(np_array_float32.tobytes())
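
The on-disk layout written by save_tensor is unchanged by this commit: a header of four int64 values holding the shape (left-padded with zeros up to four entries), followed by the raw float32 data. A reader along the following lines can load the reference tensors back; load_tensor is an illustrative helper, not something defined in the script:

import numpy as np

def load_tensor(filename):
    # Read the 4-slot int64 shape header, then the float32 payload.
    with open(filename, 'rb') as f:
        header = np.fromfile(f, dtype=np.int64, count=4)
        dims = [int(d) for d in header if d != 0]  # drop the zero padding
        data = np.fromfile(f, dtype=np.float32)
    return data.reshape(dims) if dims else data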
@@ -311,19 +314,19 @@ def patched_torch_chunk_gated_delta_rule(
         row = attn[..., i, :i].clone()
         sub = attn[..., :i, :i].clone()
         attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-        #if i <= num_heads and not long:
+        #if i <= num_heads and not long:
             #print(f"Chunk {i}: row:\n{row}\n\nsub:\n{sub}\nrow_unsq:\n{row.unsqueeze(-1)}\nrow_unsq * sub:\n{row.unsqueeze(-1)*sub}\n")
             #print(f"attn => sum = {attn[..., i, :i].sum()}, tensor: \n{attn[..., i, :i]}\n\n")
     summarize(attn, "attn_chunks")
     attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
     summarize(attn, "attn_eye")
-
+
     value = attn @ v_beta
     summarize(value, "value")
-
+
     k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
     summarize(k_cumdecay, "k_cumdecay")
-
+
     last_recurrent_state = (
         torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
         if initial_state is None
@@ -339,25 +342,25 @@ def patched_torch_chunk_gated_delta_rule(
         summarize(q_i, f"q_i_chunk_{i}")
         summarize(k_i, f"k_i_chunk_{i}")
         summarize(v_i, f"v_i_chunk_{i}")
-
+
         attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
         summarize(attn, f"attn_chunk_{i}")
-
+
         v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
         summarize(v_prime, f"v_prime_chunk_{i}")
-
+
         v_new = v_i - v_prime
         summarize(v_new, f"v_new_chunk_{i}")
-
+
         attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
         summarize(attn_inter, f"attn_inter_chunk_{i}")
-
+
         core_attn_out[:, :, i] = attn_inter + attn @ v_new
         summarize(core_attn_out[:, :, i], f"core_attn_out_chunk_{i}")
-
+
         g_last = g[:, :, i, -1, None, None].exp()
         summarize(g_last, f"g_last_chunk_{i}")
-
+
         g_diff_exp = (g[:, :, i, -1, None] - g[:, :, i]).exp()
         last_recurrent_state = (
             last_recurrent_state * g_last
@@ -371,7 +374,7 @@ def patched_torch_chunk_gated_delta_rule(
     core_attn_out = core_attn_out[:, :, :num_heads]
     core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
     summarize(core_attn_out, "attn_out")
-
+
     if isinstance(last_recurrent_state, torch.Tensor):
         summarize(last_recurrent_state, "state_out")
 
@@ -615,28 +618,28 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
     """Save KV cache tensors for each layer"""
     cache_dir = data_dir / f"kv_cache_step_{step_num}"
     cache_dir.mkdir(exist_ok=True)
-
+
     # Access past_key_values if available
     if past_key_values is not None:
         for layer_idx, cache_tuple in enumerate(past_key_values):
             if cache_tuple is None:
                 print(f"Cache tuple is None for layer {layer_idx} at step {step_num}")
                 continue
-
+
             # Handle different cache formats
             if isinstance(cache_tuple, (tuple, list)) and len(cache_tuple) >= 2:
                 key, value = cache_tuple[0], cache_tuple[1]
-
+
                 # Check if key and value are not None
                 if key is not None and value is not None:
                     # Save key cache
                     key_filename = cache_dir / f"layer_{layer_idx}_key.bin"
                     key.detach().cpu().numpy().astype(np.float32).tofile(key_filename)
-
+
                     # Save value cache
                     value_filename = cache_dir / f"layer_{layer_idx}_value.bin"
                     value.detach().cpu().numpy().astype(np.float32).tofile(value_filename)
-
+
                     print(f"Saved KV cache for layer {layer_idx} at step {step_num}: key.shape={key.shape}, value.shape={value.shape}")
                 else:
                     print(f"Key or value is None for layer {layer_idx} at step {step_num}")
@@ -738,67 +741,67 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
 print(f"\n=== Initial Forward Pass ===")
 outputs = model(input_ids, use_cache=True)
 logits = outputs.logits
-
+
 # Extract logits for the last token (next token prediction)
 last_logits = logits[0, -1, :].cpu().numpy()
 all_logits.append(last_logits)
-
+
 print(f"Logits shape: {logits.shape}")
 print(f"Last token logits shape: {last_logits.shape}")
-
+
 # Generate first token
 next_token_id = np.argmax(last_logits).item()
 all_generated_tokens.append(next_token_id)
-
+
 # Show top 5 predicted tokens for first step
 top_indices = np.argsort(last_logits)[-5:][::-1]
 print("Top 5 predictions for first token:")
 for idx in top_indices:
     token = tokenizer.decode([idx])
     print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
-
+
 print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")
-
+
 # Save KV cache if requested
 if args.save_cache:
     save_kv_cache(outputs.past_key_values, 0, data_dir, model_name)
-
+
 # Prepare for next iteration
 past_key_values = outputs.past_key_values
 current_input = torch.tensor([[next_token_id]], device=device)
-
+
 # Generate remaining tokens
 for step in range(1, args.num_tokens):
     print(f"\n=== Generation Step {step} ===")
-
+
     # Forward pass with cache
     outputs = model(
-        input_ids=current_input,
+        input_ids=current_input,
         past_key_values=past_key_values,
         use_cache=True
     )
-
+
     logits = outputs.logits
     last_logits = logits[0, -1, :].cpu().numpy()
     all_logits.append(last_logits)
-
+
     # Generate next token
     next_token_id = np.argmax(last_logits).item()
     all_generated_tokens.append(next_token_id)
-
+
     # Show top 5 predicted tokens for this step
     top_indices = np.argsort(last_logits)[-5:][::-1]
     print(f"Top 5 predictions for step {step}:")
     for idx in top_indices:
         token = tokenizer.decode([idx])
         print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
-
+
     print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")
-
+
     # Save KV cache if requested
     if args.save_cache:
         save_kv_cache(outputs.past_key_values, step, data_dir, model_name)
-
+
     # Update for next iteration
     past_key_values = outputs.past_key_values
     current_input = torch.tensor([[next_token_id]], device=device)
@@ -816,7 +819,7 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
     f.write(f"Generated tokens: {all_generated_tokens}\n")
     f.write(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}\n")
     f.write(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}\n\n")
-
+
     for step, logits in enumerate(all_logits):
         f.write(f"=== Step {step} logits ===\n")
         for i, logit in enumerate(logits):
@@ -832,4 +835,4 @@ def save_kv_cache(past_key_values, step_num, data_dir, model_name):
 print(f"Saved txt logits to: {txt_filename}")
 
 if args.save_cache:
-    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")
+    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")
