Skip to content

Commit 982eb72

Browse files
authored
Add llama3 405b attention shapes. (#29)
This PR adds the llama3 405b attention shapes that we see in the sharktank export (https://gist.github.com/KyleHerndon/a9c60ce93264d6ba7ec9e878c879f218). We make sure the dynamic sequence length is always a multiple of 16
1 parent 3f3c514 commit 982eb72

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

attentionbench/problems.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,13 @@ def bert_attn_sweep(dtype: str) -> list[AttentionConfig]:
4646
configs.append(AttentionConfig(B, M, N, K1, K2, dtype))
4747
return configs
4848

49+
def llama3_405b_attn_sweep(dtype: str) -> list[AttentionConfig]:
    """Generate llama3 405b attention benchmark configs for the given dtype.

    Sweeps the dynamic sequence length M over multiples of 1024 up to 8192,
    with K2 (the key/value sequence length) tied to M. The remaining shape
    parameters are fixed: 512 for the leading/batch dim, 128 for N and K1 —
    presumably head count and head dims per the sharktank llama3 405b export;
    confirm against AttentionConfig's field order (B, M, N, K1, K2, dtype).

    Args:
        dtype: element type string, e.g. "f16" or "f8E4M3FNUZ".

    Returns:
        A list of AttentionConfig, one per swept sequence length.
    """
    configs = []
    for M in [1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192]:
        K2 = M
        configs.append(AttentionConfig(512, M, 128, 128, K2, dtype))
        # NOTE: the original body ended with `M += 128` here. Rebinding a
        # for-loop target at the end of the body has no effect on the next
        # iteration in Python (the loop rebinds M from the list), and M was
        # already consumed above, so that statement was dead code — removed.
    return configs
4956

5057
def get_attention_configs() -> list[tuple[str, AttentionConfig]]:
5158
configs: list[tuple[str, AttentionConfig]] = []
@@ -55,9 +62,12 @@ def get_attention_configs() -> list[tuple[str, AttentionConfig]]:
5562
sdxl_configs += sdxl_unet_sweep("f8E4M3FNUZ")
5663
bert_configs = bert_attn_sweep("f16")
5764
bert_configs += bert_attn_sweep("f8E4M3FNUZ")
65+
llama3_configs = llama3_405b_attn_sweep("f16")
66+
llama3_configs += llama3_405b_attn_sweep("f8E4M3FNUZ")
5867

59-
configs += [("llm_sweep", x) for x in llm_configs]
60-
configs += [("sdxl_unet_sweep", x) for x in sdxl_configs]
61-
configs += [("bert_attn_sweep", x) for x in bert_configs]
68+
configs += [("llm", x) for x in llm_configs]
69+
configs += [("sdxl_unet", x) for x in sdxl_configs]
70+
configs += [("bert", x) for x in bert_configs]
71+
configs += [("llama3_405b", x) for x in llama3_configs]
6272

6373
return configs

0 commit comments

Comments (0)