+from itertools import chain
 from pathlib import Path
 from copy import deepcopy
-import matplotlib.pyplot as plt
 import triton.profiler as proton
-from triton.profiler import viewer
 import torch
 import argparse
 import triton_kernels
 import triton_kernels.swiglu
 from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
-from triton_kernels.target_info import is_hip, get_cdna_version
-from dataclasses import dataclass
+from triton_kernels.target_info import get_cdna_version
 import distributed as triton_dist
 from triton_kernels.tensor_details import layout
 from bench_utils import quantize_weight
+import tempfile
+import roofline
 
-if torch.cuda.is_available() and not is_hip():
-    from triton._C.libtriton import nvidia
 
-    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
-    cublas = nvidia.cublas.CublasLt(cublas_workspace)
-else:
-    cublas = None
-
-
-@dataclass
-class PerfData:
-    time: float
-    flops: float
-    bytes: float
-    bitwidth: int
-    device_type: str
-    device_info: dict
-
-    @property
-    def tflops(self):
-        return self.flops / self.time * 1e-3
-
-    @property
-    def tbps(self):
-        return self.bytes / self.time * 1e-3
-
-    @property
-    def opint(self):
-        # operational intensity
-        assert self.bytes > 0
-        return self.flops / self.bytes
-
-    @property
-    def max_tbps(self):
-        return (proton.specs.max_bps(
-            self.device_type,
-            self.device_info["arch"],
-            self.device_info["bus_width"],
-            self.device_info["memory_clock_rate"],
-        ) * 1e-12)
-
-    @property
-    def max_tflops(self):
-        return (proton.specs.max_flops(
-            self.device_type,
-            self.device_info["arch"],
-            self.bitwidth,
-            self.device_info["num_sms"],
-            self.device_info["clock_rate"],
-        ) * 1e-12)
-
-    @property
-    def util(self) -> float:
-        assert self.bitwidth in (8, 16)
-        min_t_flop = self.flops / self.max_tflops * 1e-3
-        min_t_bw = self.bytes / self.max_tbps * 1e-3
-        return max(min_t_flop, min_t_bw) / self.time
-
-
-def get_bench_path(name, rank, x_dtype, w_dtype, TP, EP):
-    return Path(f"logs/{name}/{rank}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/")
-
-
-def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name):
+def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP):
     assert n_expts_tot % EP == 0
     assert dim2 % TP == 0
     rank, world_size = triton_dist.setup()
     dev = f"cuda:{rank}"
     DP = world_size
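+    # each token activates n_expts_act of the n_expts_tot experts, so `batch` tokens yield
+    # batch * n_expts_act / n_expts_tot tokens per expert on average; inverting that relation
+    # recovers the global batch size for the requested `batch_per_expt`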
+    batch = batch_per_expt * n_expts_tot // n_expts_act
 
     assert n_expts_tot % EP == 0, f"{n_expts_tot=}, {EP=}, n_expts_tot must be divisible by EP"
     assert dim2 % TP == 0, f"{dim2=}, {TP=}, dim2 must be divisible by TP"
 
-    # input
+    # -- init data --
     # weights
     wg = triton_dist.broadcast(torch.randn((dim1, n_expts_tot), device=dev))
     w1 = torch.randn((n_expts_tot // EP, dim1, dim2 // TP), device=dev)
     w2 = torch.randn((n_expts_tot // EP, dim2 // TP // 2, dim1), device=dev)
-
     # biases
     bg = triton_dist.broadcast(torch.randn((n_expts_tot,), device=dev))
     b1 = torch.randn((n_expts_tot // EP, dim2 // TP), device=dev)
@@ -125,16 +63,15 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
     pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), weight_scale=w2_scale)
 
     # -- benchmark --
-    fpath = get_bench_path(name, rank, x_dtype, w_dtype, TP, EP) / f"profiles/batch-{batch}.hatchet"
-    fpath.parent.mkdir(parents=True, exist_ok=True)
     x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
     # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
     if x_dtype == torch.float8_e4m3fn and get_cdna_version() == 3:
         x_dtype = torch.float8_e4m3fnuz
 
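+    # the global batch is sharded across the DP (= world_size) ranks, batch // DP rows per rank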
     input_x = torch.randn((batch // DP, dim1), device=dev)
     # run layer
-    proton.start(str(fpath.with_suffix("")), hook="triton")
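+    # proton writes the profile to <session>.hatchet; it is parsed below after proton.finalize()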
+    fpath = Path(tempfile.mktemp())
+    proton.start(str(fpath), hook="triton")
     input_x = input_x.to(x_dtype)
     xg = input_x.to(wg.dtype if n_expts_tot > 1 else input_x.dtype)
     for i in range(100):
@@ -151,114 +88,66 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
                        precision_config=pc2)
         x = triton_dist.reduce_scatter(x, metadata=metadata, dim=0)
     proton.finalize()
-
-    # -- analyze --
-    gf, _, _, info = viewer.read(fpath)
-    # Now the dataframe only contains leaf nodes (i.e., kernels) that perform matmuls
-    matmuls = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*matmul.*' AND c IS LEAF").dataframe
-    bytes = matmuls["bytes"].sum()
-    flops = sum(matmuls[[c for c in ["flops8", "flops16"] if c in matmuls.columns]].sum())
-    # Compute total time (incl. "not useful" work)
-    time = gf.filter("MATCH ('*', c) WHERE c IS LEAF").dataframe["time (ns)"].sum()
-    device_type = matmuls["device_type"].iloc[0]
-    device_id = matmuls["device_id"].iloc[0]
-    device_info = info[device_type][device_id]
-    return PerfData(
-        time=time,
-        flops=flops,
-        bytes=bytes,
-        bitwidth=x.dtype.itemsize * 8,
-        device_type=device_type,
-        device_info=device_info,
-    )
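+    # profile analysis now lives in roofline.parse_profile: kernels whose names match
+    # useful_op_regex (the matmuls) count as useful flops/bytes against the total kernel time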
+    return roofline.parse_profile(fpath.with_suffix(".hatchet"), useful_op_regex=".*matmul.*")
 
 
-def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP=1, EP=1, name="",
-                 verbose=True):
-    from itertools import chain
-    from bisect import bisect_left
+def roofline_mlp(batch_sizes, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP,
+                 name="", verbose=True):
+    out_path = Path(f"logs/{name}/{x_dtype}x-{w_dtype}w-TP{TP}-EP{EP}/")
+    out_path.mkdir(parents=True, exist_ok=True)
+    csv_path = roofline.compute_roofline(dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP,  # fixed args
+                                         bench_fn=bench_mlp,  # function to benchmark
+                                         intensity_proxy_name="batch_per_expt",  # intensity proxy name
+                                         intensity_proxy_values=batch_sizes,  # intensity proxy values to sweep
+                                         verbose=verbose,  # options
+                                         out_path=out_path.with_suffix(".csv"))  # output path
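+    # "memset" / "cublas" below presumably request empirically measured roofs (achieved
+    # memset bandwidth, cuBLAS matmul throughput) rather than datasheet peaks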
+    png_path = roofline.plot_roofline(series=[csv_path],  # roofline data to plot
+                                      flops_dtype=x_dtype,  # dtype to use for FLOPS roof
+                                      xlabel="batch_per_expt", title=out_path,  # plot options
+                                      out_path=out_path.with_suffix(".png"),  # output path
+                                      max_tbps="memset", max_tflops="cublas")  # hardware limits
 
-    batches = list(chain(*[range(*r) for r in batch_ranges]))
-    # collect performance data
-    perfs = []
-    bench_case = f"{name} ({x_dtype}x{w_dtype}, TP={TP}, EP={EP})"
-    print(f"Benchmarking {bench_case}...")
-    print("===============================================================")
-    for batch in batches:
-        perfs += [bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name)]
-        if verbose:
-            print(f"Batch: {batch}; Util: {perfs[-1].util}; TFLOPS: {perfs[-1].tflops}; TBPS: {perfs[-1].tbps}")
-    print("===============================================================")
-    # machine limits
-    max_tbps = perfs[0].max_tbps
-    max_tflops = perfs[0].max_tflops
-    fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
-    ax.set_xlabel("batch size (toks/expt)")
-    ax.set_ylabel("performance [TFLOP/s]")
-    ax.set_title(f"{bench_case} roofline")
-    # add a tiny margin so points are not flush with the frame
-    xs = [batch * n_expts_act / n_expts_tot for batch in batches]
-    perf = [p.tflops for p in perfs]
-    xmin, xmax = min(xs), max(xs)
-    dx = 0.05 * (xmax - xmin) if xmax > xmin else 1.0
-    ax.set_xlim(xmin - dx, xmax + dx)
-    ax.set_ylim(100, max_tflops + 500)
-    # plot roofline
-    opints = [p.opint for p in perfs]
-    knee = bisect_left(opints, max_tflops / max_tbps)
-    if knee > 0:  # has a bandwidth-bound knee
-        x_bw = [xs[0], xs[knee - 1]]
-        y_bw = [opints[0] * max_tbps, max_tflops]
-    else:  # no knee found, compute-bound only
-        x_bw = y_bw = []
-    x_comp = xs[knee:]
-    y_comp = [max_tflops] * len(x_comp)
-    ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.1f} TB/s)", color="blue")
-    ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)", color="orange")
-    # plot data
-    ax.scatter(xs, perf, marker="+")
-    ax.legend(frameon=False, loc="lower right")
-    ax.grid(True, which="both", ls=":", lw=0.5)
-    fig.tight_layout()
-    rank, _ = triton_dist.setup()
-    fpath = get_bench_path(name, rank, x_dtype, w_dtype, TP, EP) / "roofline.png"
-    plt.savefig(fpath)
+    return png_path
 
 
 if __name__ == "__main__":
     has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
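+    # i.e. NVIDIA Blackwell (compute capability >= 10) or AMD CDNA4 expose native mxfp4 compute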
-    batch_ranges_dense = [(1024, 32768, 1024)]
-    batch_ranges_moe = [(128, 512, 32), (512, 32000, 128)]
+    batch_sizes_dense = list(range(128, 8192, 128))
+    batch_ranges_moe = [(2**(2 + k), 2**(3 + k), min(2**k, 32)) for k in range(8)]
+    batch_sizes_moe = list(chain(*[range(*r) for r in batch_ranges_moe]))
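+    # e.g. k=0 -> range(4, 8, 1) and k=5 -> range(128, 256, 32): batch_per_expt sweeps
+    # 4..992, roughly log-spaced with finer resolution at small batches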
     dense_dtypes = ["fp8", "fp8"]
     quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
     rank, world_size = triton_dist.setup()
     if world_size > 1:
         # Running all workloads at once may cause OOM on some GPUs such as H100 80GB.
         # Thus we request users to run each workload separately.
         # For example, all eligible combinations of options are listed below when four GPUs are used:
-        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 2 --ep 2 --name llama4-maverick
-        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 1 --ep 4 --name llama4-maverick
-        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name llama4-maverick
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 2 --ep 2 --name gpt-oss-x2
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 1 --ep 4 --name gpt-oss-x2
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name gpt-oss-x2
         # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name dense
-        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 2 --ep 2 --name llama4-maverick --quantized
-        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 1 --ep 4 --name llama4-maverick --quantized
-        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name llama4-maverick --quantized
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 2 --ep 2 --name gpt-oss-x2 --quantized
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 1 --ep 4 --name gpt-oss-x2 --quantized
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name gpt-oss-x2 --quantized
         # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name dense --quantized
         argparse = argparse.ArgumentParser()
         argparse.add_argument("--tp", type=int, default=1)
         argparse.add_argument("--ep", type=int, default=1)
-        argparse.add_argument("--name", type=str, choices=["dense", "llama4-maverick"])
+        argparse.add_argument("--name", type=str, choices=["dense", "gpt-oss-x2"])
         argparse.add_argument("--quantized", action="store_true", default=False)
         args = argparse.parse_args()
         dtypes = quantized_dtypes if args.quantized else dense_dtypes
         if args.name == "dense":
             assert args.ep == 1, "EP must be 1 for dense"
-            roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dtypes, TP=args.tp, EP=args.ep, name="dense")
+            roofline_mlp(batch_sizes_dense, 8192, 8192, 1, 1, *dtypes, TP=args.tp, EP=args.ep, name="dense")
         else:
-            roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dtypes, TP=args.tp, EP=args.ep, name="llama4-maverick")
+            roofline_mlp(batch_sizes_moe, 5760, 5760, 128, 4, *dtypes, TP=args.tp, EP=args.ep, name="gpt-oss-x2")
         triton_dist.cleanup()
     else:
-        roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
-        roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
-        roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
-        roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")
+        # roofline_mlp(batch_sizes_dense, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
+        # roofline_mlp(batch_sizes_moe, 5760, 5760, 128, 4, *dense_dtypes, TP=1, EP=1, name="gpt-oss-x2")
+        roofline_mlp(batch_sizes_moe, 5760, 5760, 128, 4, *quantized_dtypes, TP=1, EP=1, name="gpt-oss-x2")
+        # roofline_mlp(batch_sizes_moe, 5760, 5760, 128, 4, *quantized_dtypes, TP=2, EP=1, name="gpt-oss-x2")
+        # roofline_mlp(batch_sizes_moe, 5760, 5760, 128, 4, *quantized_dtypes, TP=4, EP=1, name="gpt-oss-x2")
+        # roofline_mlp(batch_sizes_moe, 5760, 5760, 128, 4, *quantized_dtypes, TP=8, EP=1, name="gpt-oss-x2")