Skip to content

Commit 4c8771d

Browse files
committed
Print 5D tensors
1 parent 10032af commit 4c8771d

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

examples/model-conversion/scripts/causal/run-org-model-multi-token.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,14 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
5050
- 2D tensors (seq, hidden)
5151
- 3D tensors (batch, seq, hidden)
5252
- 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
53+
- 5D tensors
5354
5455
Shows first and last max_vals of each vector per sequence position.
5556
"""
5657
t = tensor.detach().to(torch.float32).cpu()
58+
ten_shape = t.shape
59+
while t.ndim > 4:
60+
t = t.squeeze(0)
5761

5862
# Determine dimensions
5963
if t.ndim == 3:
@@ -63,12 +67,11 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
6367
t = t.unsqueeze(0)
6468
elif t.ndim == 4:
6569
_, s, _, _ = t.shape
70+
6671
else:
6772
print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
6873
return
6974

70-
ten_shape = t.shape
71-
7275
print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
7376
print(" [")
7477
print(" [")

0 commit comments

Comments
 (0)