@@ -686,19 +686,21 @@ def is_compute_bound(M: int, N: int, K: int, dtype: str, raw_accumulators: bool)
 ]


-def llama8b_prefill(dtype: str) -> list[GemmConfig]:
+def llama8b_prefill(dtype: str, pad_contraction_dimension: bool) -> list[GemmConfig]:
     configs = []
     """LLAMA 8b Prefill, FP16."""
     for m, n, k, model in LLAMA:
         if model == "8b_prefill":
             for raw_accumulators in [False, True]:
+                cache_line_size_bytes = 128
+                padded_k = k + cache_line_size_bytes // num_bytes(dtype)
                 configs.append(
                     GemmConfig(
                         m,
                         n,
-                        k,
-                        "T",
+                        padded_k if pad_contraction_dimension else k,
                         "N",
+                        "T",
                         dtype,
                         get_default_accumulator_element_type(dtype),
                         get_default_result_element_type(
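The padding adds exactly one cache line's worth of elements to the contraction dimension, so the extra element count depends on the dtype width. A minimal sketch of that arithmetic, with a stand-in num_bytes mapping and a hypothetical K (the repository's own helper may cover more dtypes):

# Sketch only: num_bytes is a stand-in for the repo's dtype-size helper.
def num_bytes(dtype: str) -> int:
    return {"f16": 2, "bf16": 2, "f32": 4}[dtype]

cache_line_size_bytes = 128
k = 4096  # hypothetical K from an 8b_prefill shape
padded_k = k + cache_line_size_bytes // num_bytes("f16")
assert padded_k == 4096 + 64  # f16 is 2 bytes, so K grows by 64 elements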
@@ -1036,7 +1038,8 @@ def square(dtype: str) -> list[GemmConfig]:
 
 
 def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
-    llama8b_prefill_configs = llama8b_prefill("f16")
+    llama8b_prefill_configs = llama8b_prefill("f16", False)
+    llama8b_prefill_padded_configs = llama8b_prefill("f16", True)
 
     llama13bmatvec_configs: list[GemmConfig] = []
     llama13bmatvec_configs += llama13bmatvec("f16")
@@ -1070,6 +1073,7 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
 
     all_configs: list[tuple[str, GemmConfig]] = []
     all_configs += [("llama8b_prefill", x) for x in llama8b_prefill_configs]
+    all_configs += [("llama8b_prefill_padded", x) for x in llama8b_prefill_padded_configs]
     all_configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs]
     all_configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs]
     all_configs += [("llama13bskinny", x) for x in llama13bskinny_configs]
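With these additions, get_gemm_configs registers the padded shapes under a separate "llama8b_prefill_padded" tag, so downstream tooling can compare the two variants by name. A small self-contained sketch of that tagging pattern (the config class and shape are stand-ins, not the real GemmConfig):

from dataclasses import dataclass

@dataclass
class Cfg:  # stand-in for GemmConfig; shape values are hypothetical
    M: int
    N: int
    K: int

def build(pad: bool) -> list[Cfg]:
    k = 4096
    return [Cfg(128, 1024, k + 64 if pad else k)]

all_configs = []
all_configs += [("llama8b_prefill", x) for x in build(False)]
all_configs += [("llama8b_prefill_padded", x) for x in build(True)]
# Runs for the padded variant can then be selected by tag:
padded_runs = [cfg for tag, cfg in all_configs if tag == "llama8b_prefill_padded"]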