@@ -619,6 +619,10 @@ def is_compute_bound(M: int, N: int, K: int, dtype: str, raw_accumulators: bool)
619619 (5120 , 16384 , 640 , "13b" ),
620620 (3456 , 16384 , 5120 , "13b" ),
621621 (5120 , 16384 , 1728 , "13b" ),
622+ (512 , 4096 , 14336 , "8b_prefill" ),
623+ (512 , 14336 , 4096 , "8b_prefill" ),
624+ (512 , 4096 , 4096 , "8b_prefill" ),
625+ (512 , 1024 , 4096 , "8b_prefill" ),
622626]
623627
624628GPT4 = [
@@ -682,6 +686,28 @@ def is_compute_bound(M: int, N: int, K: int, dtype: str, raw_accumulators: bool)
682686]
683687
684688
689+ def llama8b_prefill (dtype : str ) -> list [GemmConfig ]:
690+ configs = []
691+ """LLAMA 8b Prefill, FP16."""
692+ for m , n , k , model in LLAMA :
693+ if model == "8b_prefill" :
694+ for raw_accumulators in [False , True ]:
695+ configs .append (
696+ GemmConfig (
697+ m ,
698+ n ,
699+ k ,
700+ "T" ,
701+ "N" ,
702+ dtype ,
703+ get_default_accumulator_element_type (dtype ),
704+ get_default_result_element_type (
705+ dtype , raw_accumulators ),
706+ )
707+ )
708+ return configs
709+
710+
685711def llama13bmatvec (dtype : str ) -> list [GemmConfig ]:
686712 configs = []
687713 """LLAMA 13b, single batch, FP16."""
@@ -1010,6 +1036,8 @@ def square(dtype: str) -> list[GemmConfig]:
10101036
10111037
10121038def get_gemm_configs () -> list [tuple [str , GemmConfig ]]:
1039+ llama8b_prefill_configs = llama8b_prefill ("f16" )
1040+
10131041 llama13bmatvec_configs : list [GemmConfig ] = []
10141042 llama13bmatvec_configs += llama13bmatvec ("f16" )
10151043 llama13bmatvec_configs += llama13bmatvecbf16 ("bf16" )
@@ -1041,6 +1069,7 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
10411069 square_configs : list [GemmConfig ] = square ("f16" ) + square ("bf16" ) + square ("i8" )
10421070
10431071 all_configs : list [tuple [str , GemmConfig ]] = []
1072+ all_configs += [("llama8b_prefill" , x ) for x in llama8b_prefill_configs ]
10441073 all_configs += [("llama13bmatvec" , x ) for x in llama13bmatvec_configs ]
10451074 all_configs += [("llama70bmatvec" , x ) for x in llama70bmatvec_configs ]
10461075 all_configs += [("llama13bskinny" , x ) for x in llama13bskinny_configs ]
0 commit comments