Skip to content

Commit ac30afc

Browse files
authored
Add llama3 8b prefill gemm shapes (#39)
1 parent a70b3df commit ac30afc

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

gemmbench/problems.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -619,6 +619,10 @@ def is_compute_bound(M: int, N: int, K: int, dtype: str, raw_accumulators: bool)
619619
(5120, 16384, 640, "13b"),
620620
(3456, 16384, 5120, "13b"),
621621
(5120, 16384, 1728, "13b"),
622+
(512, 4096, 14336, "8b_prefill"),
623+
(512, 14336, 4096, "8b_prefill"),
624+
(512, 4096, 4096, "8b_prefill"),
625+
(512, 1024, 4096, "8b_prefill"),
622626
]
623627

624628
GPT4 = [
@@ -682,6 +686,28 @@ def is_compute_bound(M: int, N: int, K: int, dtype: str, raw_accumulators: bool)
682686
]
683687

684688

689+
def llama8b_prefill(dtype: str) -> list[GemmConfig]:
    """Build the GEMM configs for the LLAMA shapes tagged "8b_prefill".

    Walks the module-level ``LLAMA`` table of ``(m, n, k, model)`` tuples
    and, for every entry whose model tag is ``"8b_prefill"``, emits two
    ``GemmConfig`` objects: one with ``raw_accumulators=False`` and one
    with ``raw_accumulators=True``, in that order.

    Args:
        dtype: Element type name. It is passed to ``GemmConfig`` as-is and
            is also used to look up the default accumulator and result
            element types.

    Returns:
        The configs in ``LLAMA`` table order; empty when no entry carries
        the ``"8b_prefill"`` tag.
    """
    configs = []
    for m, n, k, model in LLAMA:
        # Guard clause: only the 8b prefill shapes belong to this problem set.
        if model != "8b_prefill":
            continue
        # raw_accumulators only influences the result element type lookup.
        for raw_accumulators in [False, True]:
            configs.append(
                GemmConfig(
                    m,
                    n,
                    k,
                    # NOTE(review): "T"/"N" are presumably the transpose
                    # flags for the two operands -- confirm against the
                    # GemmConfig definition.
                    "T",
                    "N",
                    dtype,
                    get_default_accumulator_element_type(dtype),
                    get_default_result_element_type(dtype, raw_accumulators),
                )
            )
    return configs
709+
710+
685711
def llama13bmatvec(dtype: str) -> list[GemmConfig]:
686712
configs = []
687713
"""LLAMA 13b, single batch, FP16."""
@@ -1010,6 +1036,8 @@ def square(dtype: str) -> list[GemmConfig]:
10101036

10111037

10121038
def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
1039+
llama8b_prefill_configs = llama8b_prefill("f16")
1040+
10131041
llama13bmatvec_configs: list[GemmConfig] = []
10141042
llama13bmatvec_configs += llama13bmatvec("f16")
10151043
llama13bmatvec_configs += llama13bmatvecbf16("bf16")
@@ -1041,6 +1069,7 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
10411069
square_configs: list[GemmConfig] = square("f16") + square("bf16") + square("i8")
10421070

10431071
all_configs: list[tuple[str, GemmConfig]] = []
1072+
all_configs += [("llama8b_prefill", x) for x in llama8b_prefill_configs]
10441073
all_configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs]
10451074
all_configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs]
10461075
all_configs += [("llama13bskinny", x) for x in llama13bskinny_configs]

0 commit comments

Comments
 (0)