Skip to content

Commit b220c76

Browse files
Add launch overhead microbenchmark (#7849)
Adding a script from @apgoucher to track dispatch overhead. The script never fails, but it lets us see the launch overhead whenever we make front-end changes, so we can tell whether there are significant regressions. --------- Co-authored-by: peterbell10 <[email protected]>
1 parent f971e9f commit b220c76

File tree

4 files changed

+105
-0
lines changed

4 files changed

+105
-0
lines changed

.github/workflows/integration-tests-amd.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ jobs:
164164
# Reenable test_functional_regression.py once it's fixed
165165
cd python/test/regression
166166
python3 -m pytest -s -n 8 ./test_cast_matmul.py
167+
- name: Run microbenchmark tests
168+
run: |
169+
python3 python/test/microbenchmark/launch_overhead.py
167170
- name: Run Proton tests
168171
run: |
169172
unset HIP_VISIBLE_DEVICES

.github/workflows/integration-tests-nvidia.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ jobs:
9696
run: make test-interpret
9797
- name: Run regression tests
9898
run: make test-regression
99+
- name: Run microbenchmark tests
100+
# Microbenchmarks never fail, but running them gives us an easy way to track performance changes.
101+
run: make test-microbenchmark
99102
- name: Run C++ unittests
100103
run: make test-cpp
101104
- name: Run Proton tests

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ test-gluon: all
6060
test-regression: all
6161
$(PYTEST) -s -n $(NUM_PROCS) python/test/regression
6262

63+
.PHONY: test-microbenchmark
64+
test-microbenchmark: all
65+
$(PYTHON) python/test/microbenchmark/launch_overhead.py
66+
6367
.PHONY: test-interpret
6468
test-interpret: all
6569
cd python/test/unit && TRITON_INTERPRET=1 $(PYTEST) -s -n 16 -m interpreter cuda language/test_core.py language/test_standard.py \
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
Original code by @bertmaher; profiling added by @apgoucher
3+
"""
4+
5+
import cProfile
6+
import pstats
7+
import time
8+
9+
import numpy as np
10+
import torch
11+
12+
import triton
13+
import triton.language as tl
14+
15+
16+
@triton.jit
def nop_args(
    t1, t2, t3, t4, t5,
    i1, i2, i3, i4, i5, i6, i7, i8, i9,
    c1: tl.constexpr,
    c2: tl.constexpr,
    c3: tl.constexpr,
    c4: tl.constexpr,
    c5: tl.constexpr,
):
    """No-op kernel taking 5 tensor, 9 integer, and 5 constexpr arguments.

    The body is intentionally empty so that the measured time is dominated
    by Python-side argument processing and launch/dispatch overhead.
    """
    pass
39+
40+
41+
def do_bench_walltime(fn):
    """Benchmark ``fn`` by wall-clock time, then profile it.

    Runs ``fn`` once to trigger compilation, warms up with 1000 calls,
    then measures 25 rounds of ``n_repeat`` calls each, recording the mean
    wall time per call (in milliseconds) for every round. Finally runs
    ``fn`` under cProfile for another ``n_repeat`` calls and prints the
    functions sorted by internal time.

    Parameters:
        fn: zero-argument callable to benchmark (expected to enqueue CUDA
            work; we synchronize around each timed region).

    Returns:
        numpy array of 25 per-call wall times in milliseconds.
    """
    print("Compiling...")
    fn()
    torch.cuda.synchronize()

    # Warm-up so caches (kernel binaries, launch metadata) are hot.
    for _ in range(1000):
        fn()
    torch.cuda.synchronize()

    n_repeat = 10000

    wall_times_ms = []

    for _ in range(25):
        print("Running %d benchmarking iterations..." % n_repeat)
        # Benchmark
        torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution; time.time can jump
        # with wall-clock adjustments and has coarser granularity.
        start_time = time.perf_counter()
        for _ in range(n_repeat):
            fn()
        torch.cuda.synchronize()
        end_time = time.perf_counter()
        wall_times_ms.append((end_time - start_time) * 1e3 / n_repeat)

    print("Running profiler...")
    profile = cProfile.Profile()
    profile.enable()
    for _ in range(n_repeat):
        fn()
    torch.cuda.synchronize()
    profile.disable()
    stats = pstats.Stats(profile)
    stats.sort_stats("time")
    stats.print_stats()
    return np.array(wall_times_ms)
79+
80+
81+
def main():
    """Benchmark the launch overhead of a no-op kernel and print it in µs."""
    tensor_args = [torch.zeros(1, device="cuda") for _ in range(5)]
    int_args = [1] * 9
    constexpr_args = [32] * 5

    # Measure wall time per launch; convert milliseconds -> microseconds.
    usecs = do_bench_walltime(lambda: nop_args[
        1,
    ](*tensor_args, *int_args, *constexpr_args)) * 1000.0

    print(usecs)
    # Print the median of the 25 measurements (odd count => middle element).
    ordered = sorted(usecs)
    print(ordered[len(ordered) // 2])
92+
93+
94+
# Script entry point: run the microbenchmark when executed directly.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)