Commit f8d6ecd

Add llama3 8b prefill padded gemm shapes (#40)
Also fix the default transpose. Issue: nod-ai/playbook#63
1 parent: ac30afc

File tree: 1 file changed, +8 -4 lines
gemmbench/problems.py

Lines changed: 8 additions & 4 deletions
@@ -686,19 +686,21 @@ def is_compute_bound(M: int, N: int, K: int, dtype: str, raw_accumulators: bool)
 ]
 
 
-def llama8b_prefill(dtype: str) -> list[GemmConfig]:
+def llama8b_prefill(dtype: str, pad_contraction_dimension: bool) -> list[GemmConfig]:
     configs = []
     """LLAMA 8b Prefill, FP16."""
     for m, n, k, model in LLAMA:
         if model == "8b_prefill":
             for raw_accumulators in [False, True]:
+                cache_line_size_bytes = 128
+                padded_k = k + cache_line_size_bytes // num_bytes(dtype)
                 configs.append(
                     GemmConfig(
                         m,
                         n,
-                        k,
-                        "T",
+                        padded_k if pad_contraction_dimension else k,
                         "N",
+                        "T",
                         dtype,
                         get_default_accumulator_element_type(dtype),
                         get_default_result_element_type(
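
The padding above grows the K (contraction) dimension by one cache line's worth of elements. A minimal sketch of that arithmetic, assuming num_bytes("f16") returns 2 (f16 is two bytes per element); the value of k below is illustrative and not taken from the commit:

    cache_line_size_bytes = 128                             # value used in the diff above
    bytes_per_f16 = 2                                       # assumed result of num_bytes("f16")
    k = 4096                                                 # hypothetical contraction dimension
    padded_k = k + cache_line_size_bytes // bytes_per_f16   # 4096 + 64 = 4160

With pad_contraction_dimension=True the GemmConfig is built with padded_k; with False it keeps the original k. The same hunk also swaps the two transpose arguments from "T", "N" to "N", "T", which is the default-transpose fix mentioned in the commit message.
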
@@ -1036,7 +1038,8 @@ def square(dtype: str) -> list[GemmConfig]:
 
 
 def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
-    llama8b_prefill_configs = llama8b_prefill("f16")
+    llama8b_prefill_configs = llama8b_prefill("f16", False)
+    llama8b_prefill_padded_configs = llama8b_prefill("f16", True)
 
     llama13bmatvec_configs: list[GemmConfig] = []
     llama13bmatvec_configs += llama13bmatvec("f16")
@@ -1070,6 +1073,7 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
 
     all_configs: list[tuple[str, GemmConfig]] = []
     all_configs += [("llama8b_prefill", x) for x in llama8b_prefill_configs]
+    all_configs += [("llama8b_prefill_padded", x) for x in llama8b_prefill_padded_configs]
     all_configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs]
     all_configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs]
     all_configs += [("llama13bskinny", x) for x in llama13bskinny_configs]
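
For context, a hedged usage sketch (not part of the commit) of how the new "llama8b_prefill_padded" entries can be picked out of the combined list returned by get_gemm_configs(); it assumes gemmbench is importable as a package:

    from gemmbench.problems import get_gemm_configs

    # get_gemm_configs() returns (tag, GemmConfig) pairs; this commit tags the
    # padded shapes with "llama8b_prefill_padded".
    padded = [cfg for tag, cfg in get_gemm_configs() if tag == "llama8b_prefill_padded"]
    print(f"{len(padded)} padded llama3 8b prefill GEMM shapes")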
