Skip to content

Commit f75f841

Browse files
committed
Add compare_performance_metrics.py
1 parent 276c5ab commit f75f841

File tree

2 files changed

+128
-93
lines changed

2 files changed

+128
-93
lines changed

compare_performance_metrics.py

Lines changed: 128 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,128 @@
1+
import functools
2+
import random
3+
4+
import matplotlib.pyplot as plt
5+
import pandas as pd
6+
import torch
7+
import torch.nn.functional
8+
import triton
9+
10+
import ops.ninetoothed.torch
11+
import ops.triton.torch
12+
import rotary_position_embedding
13+
from compare_code_metrics import _BACKSLASH_CHAR
14+
15+
16+
def _run_task(op_name, dtype, device, *arg_shapes, **kwarg_shapes):
    """Benchmark one operator across the NineToothed, Triton, and PyTorch backends.

    Args:
        op_name: Attribute name used to look the operator up in
            ``ops.ninetoothed.torch``, ``ops.triton.torch``, and PyTorch
            (``torch`` first, then ``torch.nn.functional``).
        dtype: dtype of every randomly generated tensor input.
        device: Device on which the tensor inputs are created.
        *arg_shapes: One shape per positional input. A falsy shape such as
            ``()`` yields a Python float drawn from ``random.gauss(0, 1)``
            instead of a tensor.
        **kwarg_shapes: Same convention as ``arg_shapes``, for keyword inputs.

    Returns:
        A ``(task_description, timings)`` pair. ``task_description`` has the
        form ``"op_name(shape, ..., key=shape, ...)"`` and ``timings`` is the
        result of ``_benchmark_ops`` for the NineToothed, Triton, and PyTorch
        implementations, in that order.
    """

    def _random_input(shape):
        # An empty shape stands for a scalar argument rather than a tensor.
        if shape:
            return torch.randn(shape, dtype=dtype, device=device)
        return random.gauss(0, 1)

    ninetoothed_impl = getattr(ops.ninetoothed.torch, op_name)
    triton_impl = getattr(ops.triton.torch, op_name)

    # Select the PyTorch implementation used as the third contender.
    if op_name == "rotary_position_embedding":
        reference_impl = rotary_position_embedding.torch_rotary_position_embedding
    elif hasattr(torch, op_name):
        reference_impl = getattr(torch, op_name)
    else:
        reference_impl = getattr(torch.nn.functional, op_name)

    # Some PyTorch ops need extra arguments that the other backends do not take.
    if op_name == "rms_norm":
        reference_impl = functools.partial(
            reference_impl, normalized_shape=arg_shapes[0][-1:]
        )
    elif op_name == "softmax":
        reference_impl = functools.partial(reference_impl, dim=-1)

    # Positional inputs are drawn before keyword inputs so the sequence of
    # random draws is reproducible for a given seed.
    positional_inputs = tuple(_random_input(shape) for shape in arg_shapes)
    keyword_inputs = {
        key: _random_input(shape) for key, shape in kwarg_shapes.items()
    }

    shape_string = ", ".join(map(str, arg_shapes))
    keyword_part = ", ".join(
        f"{key}={shape}" for key, shape in kwarg_shapes.items()
    )
    if keyword_part:
        shape_string = f"{shape_string}, {keyword_part}"

    task_description = f"{op_name}({shape_string})"

    timings = _benchmark_ops(
        (ninetoothed_impl, triton_impl, reference_impl),
        *positional_inputs,
        **keyword_inputs,
    )

    return task_description, timings
60+
61+
62+
def _benchmark_ops(ops, *args, **kwargs):
    """Check that all implementations agree, then time each of them.

    Args:
        ops: Sequence of callables implementing the same operation. The first
            one is the reference the others are compared against.
        *args: Positional arguments forwarded to every callable.
        **kwargs: Keyword arguments forwarded to every callable.

    Returns:
        A tuple with one ``triton.testing.do_bench`` result per entry of
        ``ops``, in the same order.

    Raises:
        AssertionError: If the output of any implementation is not close to
            the reference output (``rtol=0.01``, ``atol=0.01``).
    """
    if len(ops) > 1:
        # Compute the reference output once instead of once per comparison.
        reference = ops[0](*args, **kwargs)

        for index, op in enumerate(ops[1:], start=1):
            output = op(*args, **kwargs)

            # Raise explicitly rather than using `assert`, so the check still
            # runs under `python -O` and reports which implementation diverged.
            if not torch.allclose(output, reference, rtol=0.01, atol=0.01):
                op_label = getattr(op, "__name__", repr(op))
                raise AssertionError(
                    f"output of ops[{index}] ({op_label}) does not match ops[0]"
                )

    # `do_bench` runs the callable before the generator advances, so the
    # lambda always sees the `op` of the current iteration.
    return tuple(triton.testing.do_bench(lambda: op(*args, **kwargs)) for op in ops)
71+
72+
73+
if __name__ == "__main__":
    # Seed both RNGs used for input generation so runs are repeatable.
    random.seed(0)
    torch.manual_seed(0)

    plt.rcParams["figure.dpi"] = 600
    plt.rcParams["font.family"] = "Linux Biolinum"

    dtype = torch.float16
    device = "cuda"

    # Each task is (operator name, positional input shapes, keyword input
    # shapes). An empty shape `()` requests a random scalar instead of a tensor.
    tasks = (
        ("add", ((4096 * 4096,), (4096 * 4096,)), {}),
        (
            "addmm",
            ((4096, 4096), (4096, 4096), (4096, 4096)),
            {"beta": (), "alpha": ()},
        ),
        ("bmm", ((4, 2048, 2048), (4, 2048, 2048)), {}),
        ("conv2d", ((4, 512, 14, 14), (512, 512, 3, 3)), {}),
        ("mm", ((4096, 4096), (4096, 4096)), {}),
        ("rms_norm", ((4096, 4096),), {}),
        ("rotary_position_embedding", ((4, 1024, 48, 64), (1024, 32), (1024, 32)), {}),
        (
            "scaled_dot_product_attention",
            ((4, 48, 1024, 64), (4, 48, 1024, 64), (4, 48, 1024, 64)),
            {},
        ),
        ("silu", ((4096 * 4096,),), {}),
        ("softmax", ((4096, 4096),), {}),
    )

    # Provider order matches the order of the results returned by `_run_task`.
    providers = ("NineToothed", "Triton", "PyTorch")
    data = {"Task": [], "NineToothed": [], "Triton": [], "PyTorch": []}

    for name, args, kwargs in tasks:
        description, results = _run_task(name, dtype, device, *args, **kwargs)

        # Shorten the long operator names and escape underscores for LaTeX.
        latex_label = (
            description.replace("scaled_dot_product_attention", "sdpa")
            .replace("rotary_position_embedding", "rope")
            .replace("_", f"{_BACKSLASH_CHAR}_")
        )

        # "\\item" avoids the invalid "\i" escape sequence while printing the
        # same text as before.
        latex_item = f"\\item {_BACKSLASH_CHAR}texttt{{{latex_label}}}"

        print(latex_item)

        data["Task"].append(description)

        for provider, timing in zip(providers, results):
            data[provider].append(timing)

    df = pd.DataFrame(data)
    # Number the rows from 1; these numbers become the x-axis tick labels of
    # the bar chart (presumably to match the order of the printed \item list).
    df.index += 1

    # The CSV is keyed by the task description instead of the row number.
    df.set_index("Task").to_csv("performance-metrics.csv")

    df.plot(kind="bar", rot=0)
    plt.ylabel("Execution Time (ms)")
    plt.xlabel("Task")
    plt.grid(False)
    plt.tight_layout()
    plt.savefig("performance-metrics.png")

performance_comparison.py

Lines changed: 0 additions & 93 deletions
This file was deleted.

0 commit comments

Comments
 (0)