Skip to content

Commit 3a39afe

Browse files
authored
Merge pull request #7 from InfiniTensor/experiment
Refactor the code to incorporate more comprehensive experiments
2 parents c6a5a62 + 7798e39 commit 3a39afe

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+2986
-1892
lines changed

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,10 @@ cython_debug/
160160
# and can be added to the global gitignore or merged into this file. For a more nuclear
161161
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
162162
#.idea/
163+
164+
# Evaluation results
165+
*.csv
166+
*.html
167+
*.json
168+
*.png
169+
*.tex

README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
# NineToothed Examples
22

3-
This repository contains examples for [NineToothed](https://github.com/InfiniTensor/ninetoothed), including implementations of several common compute kernels written using NineToothed.
3+
This repository contains examples of [NineToothed](https://github.com/InfiniTensor/ninetoothed), including implementations of several common compute kernels written using NineToothed.
44

55
## Usage
66

77
After cloning this repository, you can run any of the examples using Python. For instance, to run the matrix multiplication example, execute the following command:
88

99
```bash
10-
python matmul.py
10+
python mm.py
1111
```
1212

1313
### Autotuning Behavior
1414

15-
By default, the examples apply autotuning, which may take several minutes or longer to complete for complex kernels. If you wish to disable autotuning, you can replace symbol definitions with concrete values. Consider the following example:
15+
Some examples apply autotuning, which may take several minutes or longer to complete for complex kernels. If you wish to disable autotuning, you can replace symbol definitions with concrete values.
16+
17+
Consider the following example:
1618

1719
```python
1820
BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)
@@ -29,6 +31,8 @@ BLOCK_SIZE = 1024
2931

3032
These approaches allow you to obtain results in seconds. However, selecting optimal values is crucial for good performance. Experiment with different values to determine the best configuration.
3133

34+
Note: Please don't forget to also disable the autotuning of the corresponding Triton compute kernels.
35+
3236
## Third-Party Code and Licenses
3337

3438
This project includes code modified from or inspired by the following open-source repositories:

add.py

Lines changed: 61 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,70 @@
1-
import ninetoothed
21
import torch
32
import triton
4-
import triton.language as tl
5-
from ninetoothed import Symbol, Tensor
63

7-
BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)
8-
9-
10-
@ninetoothed.jit
11-
def add_kernel(
12-
lhs: Tensor(1).tile((BLOCK_SIZE,)),
13-
rhs: Tensor(1).tile((BLOCK_SIZE,)),
14-
output: Tensor(1).tile((BLOCK_SIZE,)),
15-
):
16-
output = lhs + rhs # noqa: F841
17-
18-
19-
def add(lhs, rhs):
20-
output = torch.empty_like(lhs)
21-
22-
add_kernel(lhs, rhs, output)
23-
24-
return output
25-
26-
27-
@triton.jit
28-
def triton_add_kernel(
29-
lhs_ptr,
30-
rhs_ptr,
31-
output_ptr,
32-
n_elements,
33-
BLOCK_SIZE: tl.constexpr,
34-
):
35-
pid = tl.program_id(0)
36-
37-
block_start = pid * BLOCK_SIZE
38-
offsets = block_start + tl.arange(0, BLOCK_SIZE)
39-
mask = offsets < n_elements
40-
41-
lhs = tl.load(lhs_ptr + offsets, mask=mask)
42-
rhs = tl.load(rhs_ptr + offsets, mask=mask)
43-
output = lhs + rhs
44-
45-
tl.store(output_ptr + offsets, output, mask=mask)
46-
47-
48-
def triton_add(lhs, rhs):
49-
output = torch.empty_like(lhs)
50-
n_elements = output.numel()
51-
52-
def grid(meta):
53-
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
54-
55-
triton_add_kernel[grid](lhs, rhs, output, n_elements, BLOCK_SIZE=1024)
56-
57-
return output
58-
59-
60-
torch.manual_seed(0)
61-
size = 98432
62-
lhs = torch.rand(size, device="cuda")
63-
rhs = torch.rand(size, device="cuda")
64-
ninetoothed_output = add(lhs, rhs)
65-
torch_output = lhs + rhs
66-
triton_output = triton_add(lhs, rhs)
67-
print(ninetoothed_output)
68-
print(torch_output)
69-
print(triton_output)
70-
if torch.allclose(ninetoothed_output, torch_output):
71-
print("✅ NineToothed and PyTorch match.")
72-
else:
73-
print("❌ NineToothed and PyTorch differ.")
74-
if torch.allclose(ninetoothed_output, triton_output):
75-
print("✅ NineToothed and Triton match.")
76-
else:
77-
print("❌ NineToothed and Triton differ.")
78-
79-
80-
@triton.testing.perf_report(
81-
triton.testing.Benchmark(
82-
x_names=["size"],
83-
x_vals=[2**i for i in range(12, 28, 1)],
84-
x_log=True,
85-
line_arg="provider",
86-
line_vals=["ninetoothed", "torch", "triton"],
87-
line_names=["NineToothed", "PyTorch", "Triton"],
88-
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
89-
ylabel="GB/s",
90-
plot_name="vector-addition-performance",
91-
args={},
4+
import ops.ninetoothed.torch
5+
import ops.triton.torch
6+
7+
if __name__ == "__main__":
8+
torch.manual_seed(0)
9+
10+
size = 98432
11+
dtype = torch.float16
12+
device = "cuda"
13+
14+
input = torch.randn(size, dtype=dtype, device=device)
15+
other = torch.randn(size, dtype=dtype, device=device)
16+
17+
ninetoothed_output = ops.ninetoothed.torch.add(input, other)
18+
torch_output = input + other
19+
triton_output = ops.triton.torch.add(input, other)
20+
21+
print(ninetoothed_output)
22+
print(torch_output)
23+
print(triton_output)
24+
25+
if torch.allclose(ninetoothed_output, torch_output):
26+
print("✅ NineToothed and PyTorch match.")
27+
else:
28+
print("❌ NineToothed and PyTorch differ.")
29+
if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0):
30+
print("✅ NineToothed and Triton match.")
31+
else:
32+
print("❌ NineToothed and Triton differ.")
33+
34+
@triton.testing.perf_report(
35+
triton.testing.Benchmark(
36+
x_names=["size"],
37+
x_vals=[2**i for i in range(18, 28)],
38+
x_log=True,
39+
line_arg="provider",
40+
line_vals=["ninetoothed", "torch", "triton"],
41+
line_names=["NineToothed", "PyTorch", "Triton"],
42+
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
43+
ylabel="ms",
44+
plot_name="add-performance",
45+
args={},
46+
)
9247
)
93-
)
94-
def benchmark(size, provider):
95-
lhs = torch.rand(size, device="cuda", dtype=torch.float32)
96-
rhs = torch.rand(size, device="cuda", dtype=torch.float32)
97-
quantiles = [0.5, 0.2, 0.8]
48+
def benchmark(size, provider):
49+
input = torch.randn(size, dtype=dtype, device=device)
50+
other = torch.randn(size, dtype=dtype, device=device)
9851

99-
if provider == "ninetoothed":
100-
ms, min_ms, max_ms = triton.testing.do_bench(
101-
lambda: add(lhs, rhs), quantiles=quantiles
102-
)
103-
elif provider == "torch":
104-
ms, min_ms, max_ms = triton.testing.do_bench(
105-
lambda: lhs + rhs, quantiles=quantiles
106-
)
107-
elif provider == "triton":
108-
ms, min_ms, max_ms = triton.testing.do_bench(
109-
lambda: triton_add(lhs, rhs), quantiles=quantiles
110-
)
52+
ninetoothed_output = ops.ninetoothed.torch.add(input, other)
53+
torch_output = torch.add(input, other)
54+
triton_output = ops.triton.torch.add(input, other)
11155

112-
def gbps(ms):
113-
return 3 * lhs.numel() * lhs.element_size() / ms * 1e-6
56+
assert torch.allclose(ninetoothed_output, torch_output)
57+
assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
11458

115-
return gbps(ms), gbps(max_ms), gbps(min_ms)
59+
if provider == "ninetoothed":
60+
ms = triton.testing.do_bench(
61+
lambda: ops.ninetoothed.torch.add(input, other)
62+
)
63+
elif provider == "torch":
64+
ms = triton.testing.do_bench(lambda: torch.add(input, other))
65+
elif provider == "triton":
66+
ms = triton.testing.do_bench(lambda: ops.triton.torch.add(input, other))
11667

68+
return ms
11769

118-
benchmark.run(print_data=True, show_plots=True, save_path=".")
70+
benchmark.run(print_data=True, show_plots=True, save_path=".")

0 commit comments

Comments
 (0)