|
1 | | -import ninetoothed |
2 | 1 | import torch |
3 | 2 | import triton |
4 | | -import triton.language as tl |
5 | | -from ninetoothed import Symbol, Tensor |
6 | 3 |
|
7 | | -BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True) |
8 | | - |
9 | | - |
10 | | -@ninetoothed.jit |
11 | | -def add_kernel( |
12 | | - lhs: Tensor(1).tile((BLOCK_SIZE,)), |
13 | | - rhs: Tensor(1).tile((BLOCK_SIZE,)), |
14 | | - output: Tensor(1).tile((BLOCK_SIZE,)), |
15 | | -): |
16 | | - output = lhs + rhs # noqa: F841 |
17 | | - |
18 | | - |
19 | | -def add(lhs, rhs): |
20 | | - output = torch.empty_like(lhs) |
21 | | - |
22 | | - add_kernel(lhs, rhs, output) |
23 | | - |
24 | | - return output |
25 | | - |
26 | | - |
27 | | -@triton.jit |
28 | | -def triton_add_kernel( |
29 | | - lhs_ptr, |
30 | | - rhs_ptr, |
31 | | - output_ptr, |
32 | | - n_elements, |
33 | | - BLOCK_SIZE: tl.constexpr, |
34 | | -): |
35 | | - pid = tl.program_id(0) |
36 | | - |
37 | | - block_start = pid * BLOCK_SIZE |
38 | | - offsets = block_start + tl.arange(0, BLOCK_SIZE) |
39 | | - mask = offsets < n_elements |
40 | | - |
41 | | - lhs = tl.load(lhs_ptr + offsets, mask=mask) |
42 | | - rhs = tl.load(rhs_ptr + offsets, mask=mask) |
43 | | - output = lhs + rhs |
44 | | - |
45 | | - tl.store(output_ptr + offsets, output, mask=mask) |
46 | | - |
47 | | - |
48 | | -def triton_add(lhs, rhs): |
49 | | - output = torch.empty_like(lhs) |
50 | | - n_elements = output.numel() |
51 | | - |
52 | | - def grid(meta): |
53 | | - return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
54 | | - |
55 | | - triton_add_kernel[grid](lhs, rhs, output, n_elements, BLOCK_SIZE=1024) |
56 | | - |
57 | | - return output |
58 | | - |
59 | | - |
60 | | -torch.manual_seed(0) |
61 | | -size = 98432 |
62 | | -lhs = torch.rand(size, device="cuda") |
63 | | -rhs = torch.rand(size, device="cuda") |
64 | | -ninetoothed_output = add(lhs, rhs) |
65 | | -torch_output = lhs + rhs |
66 | | -triton_output = triton_add(lhs, rhs) |
67 | | -print(ninetoothed_output) |
68 | | -print(torch_output) |
69 | | -print(triton_output) |
70 | | -if torch.allclose(ninetoothed_output, torch_output): |
71 | | - print("✅ NineToothed and PyTorch match.") |
72 | | -else: |
73 | | - print("❌ NineToothed and PyTorch differ.") |
74 | | -if torch.allclose(ninetoothed_output, triton_output): |
75 | | - print("✅ NineToothed and Triton match.") |
76 | | -else: |
77 | | - print("❌ NineToothed and Triton differ.") |
78 | | - |
79 | | - |
80 | | -@triton.testing.perf_report( |
81 | | - triton.testing.Benchmark( |
82 | | - x_names=["size"], |
83 | | - x_vals=[2**i for i in range(12, 28, 1)], |
84 | | - x_log=True, |
85 | | - line_arg="provider", |
86 | | - line_vals=["ninetoothed", "torch", "triton"], |
87 | | - line_names=["NineToothed", "PyTorch", "Triton"], |
88 | | - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], |
89 | | - ylabel="GB/s", |
90 | | - plot_name="vector-addition-performance", |
91 | | - args={}, |
| 4 | +import ops.ninetoothed.torch |
| 5 | +import ops.triton.torch |
| 6 | + |
| 7 | +if __name__ == "__main__": |
| 8 | + torch.manual_seed(0) |
| 9 | + |
| 10 | + size = 98432 |
| 11 | + dtype = torch.float16 |
| 12 | + device = "cuda" |
| 13 | + |
| 14 | + input = torch.randn(size, dtype=dtype, device=device) |
| 15 | + other = torch.randn(size, dtype=dtype, device=device) |
| 16 | + |
| 17 | + ninetoothed_output = ops.ninetoothed.torch.add(input, other) |
| 18 | + torch_output = input + other |
| 19 | + triton_output = ops.triton.torch.add(input, other) |
| 20 | + |
| 21 | + print(ninetoothed_output) |
| 22 | + print(torch_output) |
| 23 | + print(triton_output) |
| 24 | + |
| 25 | + if torch.allclose(ninetoothed_output, torch_output): |
| 26 | + print("✅ NineToothed and PyTorch match.") |
| 27 | + else: |
| 28 | + print("❌ NineToothed and PyTorch differ.") |
| 29 | + if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): |
| 30 | + print("✅ NineToothed and Triton match.") |
| 31 | + else: |
| 32 | + print("❌ NineToothed and Triton differ.") |
| 33 | + |
| 34 | + @triton.testing.perf_report( |
| 35 | + triton.testing.Benchmark( |
| 36 | + x_names=["size"], |
| 37 | + x_vals=[2**i for i in range(18, 28)], |
| 38 | + x_log=True, |
| 39 | + line_arg="provider", |
| 40 | + line_vals=["ninetoothed", "torch", "triton"], |
| 41 | + line_names=["NineToothed", "PyTorch", "Triton"], |
| 42 | + styles=[("blue", "-"), ("green", "-"), ("orange", "-")], |
| 43 | + ylabel="ms", |
| 44 | + plot_name="add-performance", |
| 45 | + args={}, |
| 46 | + ) |
92 | 47 | ) |
93 | | -) |
94 | | -def benchmark(size, provider): |
95 | | - lhs = torch.rand(size, device="cuda", dtype=torch.float32) |
96 | | - rhs = torch.rand(size, device="cuda", dtype=torch.float32) |
97 | | - quantiles = [0.5, 0.2, 0.8] |
| 48 | + def benchmark(size, provider): |
| 49 | + input = torch.randn(size, dtype=dtype, device=device) |
| 50 | + other = torch.randn(size, dtype=dtype, device=device) |
98 | 51 |
|
99 | | - if provider == "ninetoothed": |
100 | | - ms, min_ms, max_ms = triton.testing.do_bench( |
101 | | - lambda: add(lhs, rhs), quantiles=quantiles |
102 | | - ) |
103 | | - elif provider == "torch": |
104 | | - ms, min_ms, max_ms = triton.testing.do_bench( |
105 | | - lambda: lhs + rhs, quantiles=quantiles |
106 | | - ) |
107 | | - elif provider == "triton": |
108 | | - ms, min_ms, max_ms = triton.testing.do_bench( |
109 | | - lambda: triton_add(lhs, rhs), quantiles=quantiles |
110 | | - ) |
| 52 | + ninetoothed_output = ops.ninetoothed.torch.add(input, other) |
| 53 | + torch_output = torch.add(input, other) |
| 54 | + triton_output = ops.triton.torch.add(input, other) |
111 | 55 |
|
112 | | - def gbps(ms): |
113 | | - return 3 * lhs.numel() * lhs.element_size() / ms * 1e-6 |
| 56 | + assert torch.allclose(ninetoothed_output, torch_output) |
| 57 | + assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) |
114 | 58 |
|
115 | | - return gbps(ms), gbps(max_ms), gbps(min_ms) |
| 59 | + if provider == "ninetoothed": |
| 60 | + ms = triton.testing.do_bench( |
| 61 | + lambda: ops.ninetoothed.torch.add(input, other) |
| 62 | + ) |
| 63 | + elif provider == "torch": |
| 64 | + ms = triton.testing.do_bench(lambda: torch.add(input, other)) |
| 65 | + elif provider == "triton": |
| 66 | + ms = triton.testing.do_bench(lambda: ops.triton.torch.add(input, other)) |
116 | 67 |
|
| 68 | + return ms |
117 | 69 |
|
118 | | -benchmark.run(print_data=True, show_plots=True, save_path=".") |
| 70 | + benchmark.run(print_data=True, show_plots=True, save_path=".") |
0 commit comments