@@ -46,6 +46,13 @@ def bert_attn_sweep(dtype: str) -> list[AttentionConfig]:
4646 configs .append (AttentionConfig (B , M , N , K1 , K2 , dtype ))
4747 return configs
4848
def llama3_405b_attn_sweep(dtype: str) -> list[AttentionConfig]:
    """Return attention benchmark configs sweeping sequence length for Llama 3 405B.

    Sweeps M over 1024..8192 in steps of 1024 and emits
    ``AttentionConfig(512, M, 128, 128, M, dtype)`` for each — K2 is set equal
    to M (presumably self-attention, where the KV length tracks the query
    length — TODO confirm against AttentionConfig's field semantics).

    Args:
        dtype: Element-type name forwarded verbatim to AttentionConfig
            (e.g. "f16", "f8E4M3FNUZ").

    Returns:
        A list of AttentionConfig, one per swept M value.
    """
    configs = []
    # NOTE(review): the original body ended the loop with `M += 128`; that
    # reassignment of the for-loop variable was dead code (the next iteration
    # overwrites M), so it has been removed with no behavior change.
    for M in range(1024, 8192 + 1, 1024):
        configs.append(AttentionConfig(512, M, 128, 128, M, dtype))
    return configs
4956
5057def get_attention_configs () -> list [tuple [str , AttentionConfig ]]:
5158 configs : list [tuple [str , AttentionConfig ]] = []
@@ -55,9 +62,12 @@ def get_attention_configs() -> list[tuple[str, AttentionConfig]]:
5562 sdxl_configs += sdxl_unet_sweep ("f8E4M3FNUZ" )
5663 bert_configs = bert_attn_sweep ("f16" )
5764 bert_configs += bert_attn_sweep ("f8E4M3FNUZ" )
65+ llama3_configs = llama3_405b_attn_sweep ("f16" )
66+ llama3_configs += llama3_405b_attn_sweep ("f8E4M3FNUZ" )
5867
59- configs += [("llm_sweep" , x ) for x in llm_configs ]
60- configs += [("sdxl_unet_sweep" , x ) for x in sdxl_configs ]
61- configs += [("bert_attn_sweep" , x ) for x in bert_configs ]
68+ configs += [("llm" , x ) for x in llm_configs ]
69+ configs += [("sdxl_unet" , x ) for x in sdxl_configs ]
70+ configs += [("bert" , x ) for x in bert_configs ]
71+ configs += [("llama3_405b" , x ) for x in llama3_configs ]
6272
6373 return configs
0 commit comments