Skip to content

Commit 5cb9403

Browse files
committed
tests/bench: add a test for forward and backward together
1 parent fb1951e commit 5cb9403

File tree

1 file changed

+77
-41
lines changed

1 file changed

+77
-41
lines changed

tests/bench.py

Lines changed: 77 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
import torch
22
import triton
3+
from typing import Literal
34

45

5-
def init(B, C, T, device):
6+
def init(B, C, T, *, device, requires_grad=False):
67
torch.manual_seed(12312323)
7-
gates = 0.999 + 0.001 * torch.rand(B, C, T, device=device)
8+
gates = 0.999 + 0.001 * torch.rand(B, C, T, device=device, requires_grad=requires_grad)
89
gates = gates.half().float()
9-
tokens = torch.rand(B, C, T, device=device)
10+
tokens = torch.rand(B, C, T, device=device, requires_grad=requires_grad)
1011
return gates, tokens
1112

1213

13-
@triton.testing.perf_report([
14-
triton.testing.Benchmark(
14+
def make_benchmark(plot_name, *, direction, max_exponent=17):
15+
return triton.testing.Benchmark(
1516
x_names=["SEQUENCE_LENGTH"], # argument names to use as an x-axis for the plot
16-
x_vals=[2**i for i in range(7,17)],
17+
x_vals=[2**i for i in range(7, max_exponent)],
1718
xlabel='sequence length',
1819
ylabel='ms',
1920
x_log=True,
@@ -23,57 +24,92 @@ def init(B, C, T, device):
2324
#line_vals=["triton", "ref", "warp"],
2425
line_names=["warp"],
2526
line_vals=["warp"],
26-
plot_name="accelerated_scan: forward speed of (8,1536,seqlen), inference mode", # name of the plot
27-
args={}
28-
),
29-
triton.testing.Benchmark(
30-
x_names=["SEQUENCE_LENGTH"], # argument names to use as an x-axis for the plot
31-
x_vals=[2**i for i in range(7,17)],
32-
xlabel='sequence length',
33-
ylabel='ms',
34-
x_log=True,
35-
y_log=True,
36-
line_arg="provider", # argument name whose value corresponds to a different line in the plot
37-
#line_names=["triton", "ref", "warp"],
38-
#line_vals=["triton", "ref", "warp"],
39-
line_names=["warp"],
40-
line_vals=["warp"],
41-
plot_name="accelerated_scan: reverse speed of (8,1536,seqlen), inference mode", # name of the plot
27+
plot_name=plot_name,
4228
args={
43-
"reverse": True,
29+
"direction": direction,
4430
}
45-
),
46-
])
47-
@torch.inference_mode()
48-
def bench(provider, SEQUENCE_LENGTH, CHUNK_LENGTH=64, device="cuda", reverse=False):
31+
)
32+
33+
34+
def grad2(f, x, y, grad_out):
35+
grad = torch.autograd.grad(f(x, y), (x, y), grad_out)
36+
sum(x.sum().item() for x in grad)
37+
38+
39+
def bench(provider, SEQUENCE_LENGTH, device="cuda", direction: Literal["forward", "backward", "train"] = "forward"):
4940
B, C, T = 8, 1536, SEQUENCE_LENGTH
50-
gates, tokens = init(B, C, T, device)
41+
gates, tokens = init(B, C, T, device=device, requires_grad=direction=="train")
5142
outputs = torch.empty_like(tokens)
43+
grad_outputs = torch.empty_like(tokens)
5244

53-
direction = "reversed" if reverse else "forward"
5445
match provider:
5546
case "triton":
56-
print(f"Running {provider} with sequence length {SEQUENCE_LENGTH} {direction}")
57-
output_gates = torch.zeros_like(gates).contiguous()
58-
from accelerated_scan.triton import forward_scan, backward_scan
59-
if reverse:
60-
scan = lambda: backward_scan[(B,C)](gates, tokens, outputs, SEQUENCE_LENGTH, enable_fp_fusion=False)
61-
else:
62-
scan = lambda: forward_scan[(B,C)](gates, tokens, outputs, SEQUENCE_LENGTH, enable_fp_fusion=False)
47+
print(f"Running {direction} {provider} with sequence length {SEQUENCE_LENGTH}")
48+
match direction:
49+
case "forward":
50+
from accelerated_scan.triton import forward_scan
51+
scan = lambda: forward_scan[(B,C)](gates, tokens, outputs, SEQUENCE_LENGTH, enable_fp_fusion=False)
52+
case "backward":
53+
from accelerated_scan.triton import backward_scan
54+
scan = lambda: backward_scan[(B,C)](gates, tokens, outputs, SEQUENCE_LENGTH, enable_fp_fusion=False)
55+
case "train":
56+
# note that these measurements include time for memory allocation for forward output tensors
57+
from accelerated_scan.triton import scan as train_scan
58+
scan = lambda: grad2(train_scan, gates, tokens, grad_outputs)
6359
case "ref":
6460
print(f"Running {provider} with sequence length {SEQUENCE_LENGTH} {direction}")
6561
from accelerated_scan.ref import scan as scan_ref
66-
scan = lambda: scan_ref(gates, tokens, reverse=reverse)
62+
match direction:
63+
case "forward":
64+
scan = lambda: scan_ref(gates, tokens)
65+
case "backward":
66+
scan = lambda: scan_ref(gates, tokens, reverse=True)
67+
case "train":
68+
scan = lambda: grad2(scan_ref, gates, tokens, grad_outputs)
6769
case "warp":
6870
print(f"Running {provider} with sequence length {SEQUENCE_LENGTH} {direction}")
69-
from accelerated_scan.warp import warpscan_forward
70-
scan = lambda: warpscan_forward(gates, tokens, outputs, reverse)
71+
match direction:
72+
case "forward":
73+
from accelerated_scan.warp import warpscan_forward
74+
scan = lambda: warpscan_forward(gates, tokens, outputs, False)
75+
case "backward":
76+
from accelerated_scan.warp import warpscan_forward
77+
scan = lambda: warpscan_forward(gates, tokens, outputs, True)
78+
case "train":
79+
# note that these measurements include time for memory allocation for forward output tensors
80+
from accelerated_scan.warp import scan as train_scan
81+
scan = lambda: grad2(train_scan, gates, tokens, grad_outputs)
7182
case _:
7283
raise ValueError(f"Unknown provider {provider}")
7384

7485
# large warmup for benefit of torch.compile
75-
ms = triton.testing.do_bench(scan, warmup=5000, rep=100)
86+
if direction == "train":
87+
ms = triton.testing.do_bench(scan, warmup=5000, rep=100)
88+
else:
89+
with torch.inference_mode():
90+
ms = triton.testing.do_bench(scan, warmup=5000, rep=100)
7691
return ms
7792

93+
7894
if __name__ == '__main__':
79-
bench.run(save_path=".", print_data=True)
95+
import argparse
96+
parser = argparse.ArgumentParser()
97+
parser.add_argument("--direction", choices=["forward", "backward", "train", "all"], default="all")
98+
args = parser.parse_args()
99+
100+
directions = {
101+
'forward': make_benchmark("accelerated_scan: forward speed of (8,1536,seqlen), inference mode", direction="forward"),
102+
'backward': make_benchmark("accelerated_scan: backward speed of (8,1536,seqlen), inference mode", direction="backward"),
103+
'train': make_benchmark("accelerated_scan: training speed of (8,1536,seqlen)", direction="train", max_exponent=15),
104+
}
105+
106+
benchmarks = []
107+
match args.direction:
108+
case "all":
109+
benchmarks.append(directions['forward'])
110+
benchmarks.append(directions['backward'])
111+
benchmarks.append(directions['train'])
112+
case dir:
113+
benchmarks.append(directions[dir])
114+
115+
triton.testing.perf_report(benchmarks)(bench).run(save_path=".", print_data=True)

0 commit comments

Comments
 (0)