Skip to content

Commit 5d295e0

Browse files
committed
add note
1 parent 2ae7084 commit 5d295e0

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage2.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
def create_tensors(shared_seq_len):
77
batch_size = 4
8-
num_heads = 4
9-
kv_head_num = 1
8+
num_heads = 32
9+
kv_head_num = 8
1010
seq_len = 256
1111
head_dim = 128
1212
max_len_in_batch = seq_len
@@ -113,6 +113,14 @@ def test_flash_decode_stage2_execution(shared_seq_len):
113113
print(f"\nshared_seq_len={shared_seq_len}")
114114
print(f"mid_out: {mid_out[0:4, 0, 0, 0]}")
115115
print(f"true_mid_out: {true_mid_out[0:4, 0, 0, 0]}")
116+
abs_diff = (mid_out - true_mid_out).abs()
117+
max_diff = abs_diff.max()
118+
max_diff_idx = abs_diff.argmax()
119+
max_diff_idx_unraveled = torch.unravel_index(max_diff_idx, abs_diff.shape)
120+
mid_out_value = mid_out[max_diff_idx_unraveled]
121+
true_mid_out_value = true_mid_out[max_diff_idx_unraveled]
122+
print(f"max abs diff: {max_diff}, mid_out value: {mid_out_value}, " f"true_mid_out value: {true_mid_out_value}")
123+
116124
assert torch.allclose(
117125
mid_out[0:4, 0, 0, 0], true_mid_out[0:4, 0, 0, 0], atol=1e-2
118126
), f"Mid output does not match expected values for shared_seq_len={shared_seq_len}"

0 commit comments

Comments
 (0)