Skip to content

Commit 9d006c9

Browse files
committed
Add performance comparison
1 parent 83f7596 commit 9d006c9

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed

performance_comparison.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from dataclasses import dataclass
2+
3+
import matplotlib.pyplot as plt
4+
import pandas as pd
5+
6+
plt.rcParams["figure.figsize"] = [12, 6]
7+
plt.rcParams["figure.dpi"] = 600
8+
plt.rcParams["font.family"] = "JetBrains Mono"
9+
plt.rcParams["font.weight"] = "bold"
10+
plt.rcParams["axes.titleweight"] = "bold"
11+
plt.rcParams["axes.labelweight"] = "bold"
12+
13+
14+
@dataclass
15+
class KernelInformation:
16+
name: str
17+
memory_bound: bool
18+
compute_bound: bool
19+
perf_report_path: str
20+
independent_variable: str
21+
22+
23+
@dataclass
24+
class CategoryInformation:
25+
kernels: tuple
26+
y_label: str
27+
28+
29+
kernels = (
30+
KernelInformation("add", True, False, "vector-addition-performance.csv", "Length"),
31+
KernelInformation(
32+
"softmax", True, False, "softmax-performance.csv", "Number of Columns"
33+
),
34+
KernelInformation(
35+
"rms_norm", True, False, "rms-norm-performance.csv", "Number of Columns"
36+
),
37+
KernelInformation(
38+
"matmul", False, True, "matrix-multiplication-performance.csv", "Sizes"
39+
),
40+
KernelInformation(
41+
"conv2d", False, True, "2d-convolution-performance.csv", "Batch Size"
42+
),
43+
KernelInformation(
44+
"attention", False, True, "attention-performance.csv", "Sequence Length"
45+
),
46+
)
47+
48+
providers = ("Triton", "NineToothed")
49+
50+
categories = (
51+
CategoryInformation(
52+
tuple(kernel for kernel in kernels if kernel.memory_bound), "GB/s"
53+
),
54+
CategoryInformation(
55+
tuple(kernel for kernel in kernels if kernel.compute_bound), "TFLOPS"
56+
),
57+
)
58+
59+
num_rows = len(categories)
60+
num_cols = max(len(category.kernels) for category in categories)
61+
62+
fig, axs = plt.subplots(num_rows, num_cols)
63+
64+
for row, category in enumerate(categories):
65+
axs[row, 0].set_ylabel(category.y_label)
66+
67+
for col, kernel in enumerate(category.kernels):
68+
df = pd.read_csv(kernel.perf_report_path)
69+
ax = axs[row, col]
70+
71+
x = df.iloc[:, 0]
72+
73+
for provider in providers:
74+
y = df[provider]
75+
76+
ax.plot(x, y, label=provider)
77+
78+
ax.set_title(kernel.name)
79+
ax.set_xlabel(kernel.independent_variable)
80+
ax.set_xscale("log", base=2)
81+
82+
fig.legend(providers, loc="upper center", ncols=len(providers))
83+
fig.tight_layout()
84+
fig.subplots_adjust(top=0.9)
85+
86+
plt.show()
87+
plt.savefig("performance-comparison.png")

0 commit comments

Comments
 (0)