Skip to content

Commit 26652c1

Browse files
committed
Add mm_softmax_mm.py
1 parent 0600f3e commit 26652c1

File tree

1 file changed

+129
-0
lines changed

1 file changed

+129
-0
lines changed

mm_softmax_mm.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import ninetoothed
2+
import torch
3+
import triton
4+
5+
import ops.ninetoothed.torch
6+
7+
8+
def ninetoothed_mm_softmax_mm(a, b, c, use_fused=True):
    """Compute ``softmax(a @ b) @ c`` with NineToothed kernels.

    Dispatches to the fused path (compiled with the NineToothed fuser)
    when *use_fused* is true, otherwise to the eager, unfused path.
    """
    if not use_fused:
        return ninetoothed_mm_softmax_mm_unfused(a, b, c)
    return ninetoothed_mm_softmax_mm_fused(a, b, c)
13+
14+
15+
@torch.compile(backend=ninetoothed.fuser)
def ninetoothed_mm_softmax_mm_fused(a, b, c):
    """Fused variant: the unfused pipeline traced through ``torch.compile``
    with the NineToothed fuser backend so the three ops can be fused."""
    result = ninetoothed_mm_softmax_mm_unfused(a, b, c)
    return result
18+
19+
20+
def ninetoothed_mm_softmax_mm_unfused(a, b, c):
    """Eagerly chain the three NineToothed ops: mm -> softmax -> mm."""
    scores = ninetoothed_mm(a, b)
    probs = ninetoothed_softmax(scores)
    return ninetoothed_mm(probs, c)
22+
23+
24+
def ninetoothed_softmax(a):
    """Softmax via the NineToothed op library (``impl_id=1`` variant)."""
    result = ops.ninetoothed.torch.softmax(a, impl_id=1)
    return result
26+
27+
28+
def ninetoothed_mm(a, b):
    """Matrix multiply via the NineToothed op library (``impl_id=1`` variant)."""
    result = ops.ninetoothed.torch.mm(a, b, impl_id=1)
    return result
30+
31+
32+
def torch_mm_softmax_mm(a, b, c, use_compiled=True):
    """PyTorch reference for ``softmax(a @ b) @ c``.

    Selects the ``torch.compile``-d implementation when *use_compiled*
    is true, and the plain eager implementation otherwise.
    """
    if not use_compiled:
        return torch_mm_softmax_mm_uncompiled(a, b, c)
    return torch_mm_softmax_mm_compiled(a, b, c)
37+
38+
39+
def torch_mm_softmax_mm_uncompiled(a, b, c):
    """Eager PyTorch pipeline: matmul, row-wise softmax, matmul."""
    scores = torch.mm(a, b)
    probs = torch.softmax(scores, dim=-1)
    return torch.mm(probs, c)
41+
42+
43+
@torch.compile
def torch_mm_softmax_mm_compiled(a, b, c):
    """Same computation as the eager version, but run through ``torch.compile``."""
    result = torch_mm_softmax_mm_uncompiled(a, b, c)
    return result
46+
47+
48+
if __name__ == "__main__":
    # Fixed seed so the correctness comparison below is reproducible.
    torch.manual_seed(0)

    shape = (512, 512)
    dtype = torch.float16
    device = "cuda"

    a = torch.randn(shape, dtype=dtype, device=device)
    b = torch.randn(shape, dtype=dtype, device=device)
    c = torch.randn(shape, dtype=dtype, device=device)

    # Run both the NineToothed and the PyTorch implementation on the
    # same inputs and compare the results.
    ninetoothed_output = ninetoothed_mm_softmax_mm(a, b, c)
    torch_output = torch_mm_softmax_mm(a, b, c)

    print(ninetoothed_output)
    print(torch_output)

    # Loose tolerances: float16 matmul plus softmax accumulates rounding
    # error, so bitwise equality is not expected.
    if torch.allclose(ninetoothed_output, torch_output, rtol=1e-3, atol=1e-3):
        print("✅ NineToothed and PyTorch match.")
    else:
        print("❌ NineToothed and PyTorch differ.")

    # Benchmark all four variants across m = 8 .. 4096 (log scale),
    # with the inner dimension fixed at n = 128.
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["m"],
            x_vals=[2**i for i in range(3, 13)],
            x_log=True,
            line_arg="provider",
            line_vals=[
                "ninetoothed_fused",
                "torch_compiled",
                "ninetoothed_unfused",
                "torch_uncompiled",
            ],
            line_names=[
                "NineToothed (Fused)",
                "PyTorch (Compiled)",
                "NineToothed (Unfused)",
                "PyTorch (Uncompiled)",
            ],
            styles=[("blue", "-"), ("green", "-"), ("orange", "-"), ("purple", "-")],
            ylabel="ms",
            plot_name="mm-softmax-mm-performance",
            args={"n": 128},
        )
    )
    def benchmark(m, n, provider):
        # (m, n) @ (n, m) -> (m, m); softmax keeps the shape; (m, m) @ (m, n)
        # -> (m, n), so the rectangular operands stay shape-consistent.
        a = torch.randn((m, n), dtype=dtype, device=device)
        b = torch.randn((n, m), dtype=dtype, device=device)
        c = torch.randn((m, n), dtype=dtype, device=device)

        # Sanity-check the NineToothed providers against the matching
        # PyTorch implementation before timing them.
        if provider == "ninetoothed_fused":
            ninetoothed_output = ninetoothed_mm_softmax_mm(a, b, c)
            torch_output = torch_mm_softmax_mm(a, b, c)

            assert torch.allclose(
                ninetoothed_output, torch_output, rtol=1e-3, atol=1e-3
            )
        elif provider == "ninetoothed_unfused":
            ninetoothed_output = ninetoothed_mm_softmax_mm(a, b, c, use_fused=False)
            torch_output = torch_mm_softmax_mm(a, b, c, use_compiled=False)

            assert torch.allclose(
                ninetoothed_output, torch_output, rtol=1e-3, atol=1e-3
            )

        # Time the variant selected by `provider`.
        if provider == "ninetoothed_fused":
            ms = triton.testing.do_bench(lambda: ninetoothed_mm_softmax_mm(a, b, c))
        elif provider == "torch_compiled":
            ms = triton.testing.do_bench(lambda: torch_mm_softmax_mm(a, b, c))
        elif provider == "ninetoothed_unfused":
            ms = triton.testing.do_bench(
                lambda: ninetoothed_mm_softmax_mm(a, b, c, use_fused=False)
            )
        elif provider == "torch_uncompiled":
            ms = triton.testing.do_bench(
                lambda: torch_mm_softmax_mm(a, b, c, use_compiled=False)
            )

        return ms

    benchmark.run(show_plots=True, print_data=True, save_path=".")

0 commit comments

Comments
 (0)