@@ -1,19 +1,20 @@
 import torch
 import triton
+from typing import Literal


-def init(B, C, T, device):
+def init(B, C, T, *, device, requires_grad=False):
     torch.manual_seed(12312323)
-    gates = 0.999 + 0.001 * torch.rand(B, C, T, device=device)
+    gates = 0.999 + 0.001 * torch.rand(B, C, T, device=device, requires_grad=requires_grad)
     gates = gates.half().float()
-    tokens = torch.rand(B, C, T, device=device)
+    tokens = torch.rand(B, C, T, device=device, requires_grad=requires_grad)
     return gates, tokens


-@triton.testing.perf_report([
-    triton.testing.Benchmark(
+def make_benchmark(plot_name, *, direction, max_exponent=17):
+    return triton.testing.Benchmark(
         x_names=["SEQUENCE_LENGTH"],  # argument names to use as an x-axis for the plot
-        x_vals=[2**i for i in range(7, 17)],
+        x_vals=[2**i for i in range(7, max_exponent)],
         xlabel='sequence length',
         ylabel='ms',
         x_log=True,
@@ -23,57 +24,92 @@ def init(B, C, T, device):
         #line_vals=["triton", "ref", "warp"],
         line_names=["warp"],
         line_vals=["warp"],
-        plot_name="accelerated_scan: forward speed of (8,1536,seqlen), inference mode",  # name of the plot
-        args={}
-    ),
-    triton.testing.Benchmark(
-        x_names=["SEQUENCE_LENGTH"],  # argument names to use as an x-axis for the plot
-        x_vals=[2**i for i in range(7, 17)],
-        xlabel='sequence length',
-        ylabel='ms',
-        x_log=True,
-        y_log=True,
-        line_arg="provider",  # argument name whose value corresponds to a different line in the plot
-        #line_names=["triton", "ref", "warp"],
-        #line_vals=["triton", "ref", "warp"],
-        line_names=["warp"],
-        line_vals=["warp"],
-        plot_name="accelerated_scan: reverse speed of (8,1536,seqlen), inference mode",  # name of the plot
+        plot_name=plot_name,
         args={
-            "reverse": True,
+            "direction": direction,
         }
-    ),
-])
-@torch.inference_mode()
-def bench(provider, SEQUENCE_LENGTH, CHUNK_LENGTH=64, device="cuda", reverse=False):
+    )
+
+
+def grad2(f, x, y, grad_out):
+    grad = torch.autograd.grad(f(x, y), (x, y), grad_out)
+    sum(x.sum().item() for x in grad)
+
+
+def bench(provider, SEQUENCE_LENGTH, device="cuda", direction: Literal["forward", "backward", "train"] = "forward"):
     B, C, T = 8, 1536, SEQUENCE_LENGTH
-    gates, tokens = init(B, C, T, device)
+    gates, tokens = init(B, C, T, device=device, requires_grad=direction == "train")
     outputs = torch.empty_like(tokens)
+    grad_outputs = torch.empty_like(tokens)

-    direction = "reversed" if reverse else "forward"
     match provider:
         case "triton":
-            print(f"Running {provider} with sequence length {SEQUENCE_LENGTH} {direction}")
-            output_gates = torch.zeros_like(gates).contiguous()
-            from accelerated_scan.triton import forward_scan, backward_scan
-            if reverse:
-                scan = lambda: backward_scan[(B,C)](gates, tokens, outputs, SEQUENCE_LENGTH, enable_fp_fusion=False)
-            else:
-                scan = lambda: forward_scan[(B,C)](gates, tokens, outputs, SEQUENCE_LENGTH, enable_fp_fusion=False)
+            print(f"Running {direction} {provider} with sequence length {SEQUENCE_LENGTH}")
+            match direction:
+                case "forward":
+                    from accelerated_scan.triton import forward_scan
+                    scan = lambda: forward_scan[(B,C)](gates, tokens, outputs, SEQUENCE_LENGTH, enable_fp_fusion=False)
+                case "backward":
+                    from accelerated_scan.triton import backward_scan
+                    scan = lambda: backward_scan[(B,C)](gates, tokens, outputs, SEQUENCE_LENGTH, enable_fp_fusion=False)
+                case "train":
+                    # note that these measurements include time for memory allocation for forward output tensors
+                    from accelerated_scan.triton import scan as train_scan
+                    scan = lambda: grad2(train_scan, gates, tokens, grad_outputs)
         case "ref":
             print(f"Running {provider} with sequence length {SEQUENCE_LENGTH} {direction}")
             from accelerated_scan.ref import scan as scan_ref
-            scan = lambda: scan_ref(gates, tokens, reverse=reverse)
+            match direction:
+                case "forward":
+                    scan = lambda: scan_ref(gates, tokens)
+                case "backward":
+                    scan = lambda: scan_ref(gates, tokens, reverse=True)
+                case "train":
+                    scan = lambda: grad2(scan_ref, gates, tokens, grad_outputs)
         case "warp":
             print(f"Running {provider} with sequence length {SEQUENCE_LENGTH} {direction}")
-            from accelerated_scan.warp import warpscan_forward
-            scan = lambda: warpscan_forward(gates, tokens, outputs, reverse)
+            match direction:
+                case "forward":
+                    from accelerated_scan.warp import warpscan_forward
+                    scan = lambda: warpscan_forward(gates, tokens, outputs, False)
+                case "backward":
+                    from accelerated_scan.warp import warpscan_forward
+                    scan = lambda: warpscan_forward(gates, tokens, outputs, True)
+                case "train":
+                    # note that these measurements include time for memory allocation for forward output tensors
+                    from accelerated_scan.warp import scan as train_scan
+                    scan = lambda: grad2(train_scan, gates, tokens, grad_outputs)
         case _:
             raise ValueError(f"Unknown provider {provider}")

     # large warmup for benefit of torch.compile
-    ms = triton.testing.do_bench(scan, warmup=5000, rep=100)
+    if direction == "train":
+        ms = triton.testing.do_bench(scan, warmup=5000, rep=100)
+    else:
+        with torch.inference_mode():
+            ms = triton.testing.do_bench(scan, warmup=5000, rep=100)
     return ms

+
 if __name__ == '__main__':
-    bench.run(save_path=".", print_data=True)
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--direction", choices=["forward", "backward", "train", "all"], default="all")
+    args = parser.parse_args()
+
+    directions = {
+        'forward': make_benchmark("accelerated_scan: forward speed of (8,1536,seqlen), inference mode", direction="forward"),
+        'backward': make_benchmark("accelerated_scan: backward speed of (8,1536,seqlen), inference mode", direction="backward"),
+        'train': make_benchmark("accelerated_scan: training speed of (8,1536,seqlen)", direction="train", max_exponent=15),
+    }
+
+    benchmarks = []
+    match args.direction:
+        case "all":
+            benchmarks.append(directions['forward'])
+            benchmarks.append(directions['backward'])
+            benchmarks.append(directions['train'])
+        case dir:
+            benchmarks.append(directions[dir])
+
+    triton.testing.perf_report(benchmarks)(bench).run(save_path=".", print_data=True)
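
Note on the new "train" direction: grad2 times a full forward and backward pass by calling torch.autograd.grad on the scan output with an explicit incoming gradient, then reads each gradient back with .item(), which forces a device synchronization. The standalone sketch below is not part of the commit; toy_scan and the shapes are invented for illustration, standing in for the accelerated_scan kernels, but the autograd pattern is the same:

import torch

def toy_scan(gates, tokens):
    # stand-in for the real scan: any differentiable function of both inputs
    return gates * tokens

B, C, T = 2, 3, 16
gates = torch.rand(B, C, T, requires_grad=True)
tokens = torch.rand(B, C, T, requires_grad=True)
grad_outputs = torch.ones(B, C, T)  # incoming gradient, same shape as the forward output

# forward + backward w.r.t. both inputs, mirroring grad2 in the benchmark
grads = torch.autograd.grad(toy_scan(gates, tokens), (gates, tokens), grad_outputs)
total = sum(g.sum().item() for g in grads)  # .item() synchronizes with the device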
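
For context on the triton.testing plumbing that make_benchmark wraps: perf_report takes a list of Benchmark configs and returns a decorator, and the decorated function gains a .run() that sweeps every config, which is why the entry point can apply perf_report(benchmarks)(bench) at runtime instead of stacking the decorator statically as the old code did. Below is a minimal self-contained sketch of that mechanism; it assumes a CUDA device is available, and toy_bench with its config are invented for illustration, not part of the commit:

import triton

config = triton.testing.Benchmark(
    x_names=["N"],
    x_vals=[2**i for i in range(4, 8)],
    xlabel="N",
    ylabel="ms",
    line_arg="provider",
    line_names=["python"],
    line_vals=["python"],
    plot_name="toy_benchmark",
    args={},
)

@triton.testing.perf_report([config])
def toy_bench(N, provider):
    # time a trivial callable; do_bench handles warmup, repetition and timing
    return triton.testing.do_bench(lambda: sum(range(N)))

# .run() sweeps x_vals for each line; the diff applies the same machinery
# programmatically as triton.testing.perf_report(benchmarks)(bench).run(...)
toy_bench.run(save_path=".", print_data=True)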