import triton.profiler as proton
from triton.profiler import viewer
import torch
+ import argparse
import triton_kernels
import triton_kernels.swiglu
- from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
- from triton_kernels.numerics import InFlexData
- from triton_kernels.routing import routing
- from triton_kernels.target_info import is_cuda, is_hip, get_cdna_version, cuda_capability_geq
- from triton_kernels.tensor import convert_layout
- from triton_kernels.tensor import wrap_torch_tensor, FP4
+ from triton_kernels.target_info import is_hip, get_cdna_version
from dataclasses import dataclass
+ import distributed as triton_dist
from triton_kernels.tensor_details import layout
+ from bench_utils import quantize_weight

if torch.cuda.is_available() and not is_hip():
    from triton._C.libtriton import nvidia
+
    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    cublas = nvidia.cublas.CublasLt(cublas_workspace)
else:
    cublas = None


- def quantize(w, dtype, **opt):
-     if dtype == "bf16":
-         wq = w.to(torch.bfloat16).transpose(-1, -2).contiguous().transpose(-1, -2)
-         return wq, InFlexData(), None
-     elif dtype == "fp8":
-         fp8e4_dtype = torch.float8_e4m3fn if get_cdna_version() != 3 \
-             else torch.float8_e4m3fnuz
-         wq = w.to(fp8e4_dtype)
-         if is_cuda() and not cuda_capability_geq(10, 0):
-             wq = wq.transpose(-1, -2).contiguous().transpose(-1, -2)
-         return wq, InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)), None
-     else:
-         assert dtype == "mx4", f"{dtype=}"
-         w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
-         if opt:
-             w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
-             w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
-         return w, InFlexData(), w_scale
-
-
@dataclass
class PerfData:
    time: float
@@ -69,13 +48,22 @@ def opint(self):

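+     # peak memory-bandwidth and compute ceilings for the roofline, scaled by 1e-12 into TB/s and TFLOP/s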
    @property
    def max_tbps(self):
-         return proton.specs.max_bps(self.device_type, self.device_info["arch"], self.device_info["bus_width"],
-                                     self.device_info["memory_clock_rate"]) * 1e-12
+         return (proton.specs.max_bps(
+             self.device_type,
+             self.device_info["arch"],
+             self.device_info["bus_width"],
+             self.device_info["memory_clock_rate"],
+         ) * 1e-12)

    @property
    def max_tflops(self):
-         return proton.specs.max_flops(self.device_type, self.device_info["arch"], self.bitwidth,
-                                       self.device_info["num_sms"], self.device_info["clock_rate"]) * 1e-12
+         return (proton.specs.max_flops(
+             self.device_type,
+             self.device_info["arch"],
+             self.bitwidth,
+             self.device_info["num_sms"],
+             self.device_info["clock_rate"],
+         ) * 1e-12)

    @property
    def util(self) -> float:
@@ -85,62 +73,83 @@ def util(self) -> float:
        return max(min_t_flop, min_t_bw) / self.time


+ def get_bench_path(name, rank, x_dtype, w_dtype, TP, EP):
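+     # e.g. get_bench_path("dense", 0, "fp8", "fp8", 1, 1) -> logs/dense/0/fp8-fp8-TP1-EP1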
+     return Path(f"logs/{name}/{rank}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/")
+
+
def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name):
    assert n_expts_tot % EP == 0
    assert dim2 % TP == 0
-     dev = "cuda"
+     rank, world_size = triton_dist.setup()
+     dev = f"cuda:{rank}"
+     DP = world_size
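+     # data-parallel width: each rank benchmarks its own shard of the global batch (batch // DP rows)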
+
+     assert n_expts_tot % EP == 0, f"{n_expts_tot=}, {EP=}, n_expts_tot must be divisible by EP"
+     assert dim2 % TP == 0, f"{dim2=}, {TP=}, dim2 must be divisible by TP"

    # input
    # weights
-     wg = torch.randn((dim1, n_expts_tot), device=dev)
+     wg = triton_dist.broadcast(torch.randn((dim1, n_expts_tot), device=dev))
    w1 = torch.randn((n_expts_tot // EP, dim1, dim2 // TP), device=dev)
    w2 = torch.randn((n_expts_tot // EP, dim2 // TP // 2, dim1), device=dev)
+
    # biases
-     bg = torch.randn((n_expts_tot, ), device=dev)
+     bg = triton_dist.broadcast(torch.randn((n_expts_tot, ), device=dev))
    b1 = torch.randn((n_expts_tot // EP, dim2 // TP), device=dev)
    b2 = torch.randn((n_expts_tot // EP, dim1), device=dev)
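+     # expert-parallel groups are blocks of TP consecutive ranks, e.g. TP=2, EP=2 -> groups = [[0, 1], [2, 3]]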
+     ep_indx = (rank // TP) % EP
+     groups = [list(range(ep * TP, (ep + 1) * TP)) for ep in range(EP)]
+     b2 = triton_dist.broadcast(b2, src=ep_indx * TP, groups=groups, group_idx=ep_indx)

    # -- numerics --
-     optg = dict()
    opt1 = dict()
    opt2 = dict()
    if w_dtype == "mx4" and not is_hip():
        num_warps = 4 if batch <= 512 else 8
        value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
        scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
            mx_axis=1, num_warps=num_warps)
-         opt1 = {"value_layout": value_layout, "value_layout_opts": value_layout_opts, \
-                 "scale_layout": scale_layout, "scale_layout_opts": scale_layout_opts}
+         opt1 = {
+             "value_layout": value_layout,
+             "value_layout_opts": value_layout_opts,
+             "scale_layout": scale_layout,
+             "scale_layout_opts": scale_layout_opts,
+         }
        opt2 = deepcopy(opt1)
-     wg, wg_flex, wg_scale = quantize(wg, "bf16", **optg)
-     w1, w1_flex, w1_scale = quantize(w1, w_dtype, **opt1)
-     w2, w2_flex, w2_scale = quantize(w2, w_dtype, **opt2)
+     wg, wg_flex, wg_scale = quantize_weight(wg, "bf16")
+     w1, w1_flex, w1_scale = quantize_weight(w1, w_dtype, **opt1)
+     w2, w2_flex, w2_scale = quantize_weight(w2, w_dtype, **opt2)
    pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), weight_scale=wg_scale)
    act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (1.0, 1.0), 2)
    pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), weight_scale=w1_scale)
    pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), weight_scale=w2_scale)

    # -- benchmark --
-     fpath = Path(f"logs/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/profiles/batch-{batch}.hatchet")
+     fpath = get_bench_path(name, rank, x_dtype, w_dtype, TP, EP) / f"profiles/batch-{batch}.hatchet"
    fpath.parent.mkdir(parents=True, exist_ok=True)
    x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
    # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
    if x_dtype == torch.float8_e4m3fn and get_cdna_version() == 3:
        x_dtype = torch.float8_e4m3fnuz

-     x = torch.randn((batch, dim1), device=dev)
-     xg = x.to(wg.dtype if n_expts_tot > 1 else x_dtype)
-     x = x.to(x_dtype)
+     input_x = torch.randn((batch // DP, dim1), device=dev)
    # run layer
-     proton.start(str(fpath.with_suffix('')), hook="triton")
+     proton.start(str(fpath.with_suffix("")), hook="triton")
+     input_x = input_x.to(x_dtype)
+     xg = input_x.to(wg.dtype if n_expts_tot > 1 else input_x.dtype)
    for i in range(100):
-         if n_expts_tot > 1:
+         if n_expts_tot > 1:  # sparse
            logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
-             rdata, gather_indx, scatter_indx = routing(logits, n_expts_act, simulated_ep=EP)
-         else:
-             rdata, gather_indx, scatter_indx = None, None, None
-         x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)
-         x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2)
+             x, rdata, gather_indx, scatter_indx, metadata = triton_dist.routing(input_x, logits, n_expts_act, EP=EP,
+                                                                                 TP=TP)
+         else:  # dense
+             x = triton_dist.all_gather(input_x, dim=0)
+             rdata, gather_indx, scatter_indx, metadata = None, None, None, None
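+         # a rank can be left with zero assigned tokens after routing, so skip the matmuls on empty inputs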
+         if x.nelement() > 0:
+             x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)
+             x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx,
+                            precision_config=pc2)
+         x = triton_dist.reduce_scatter(x, metadata=metadata, dim=0)
    proton.finalize()

    # -- analyze --
@@ -153,14 +162,21 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
    device_type = matmuls["device_type"].iloc[0]
    device_id = matmuls["device_id"].iloc[0]
    device_info = info[device_type][device_id]
-     return PerfData(time=time, flops=flops, bytes=bytes, bitwidth=x.dtype.itemsize * 8, device_type=device_type,
-                     device_info=device_info)
+     return PerfData(
+         time=time,
+         flops=flops,
+         bytes=bytes,
+         bitwidth=x.dtype.itemsize * 8,
+         device_type=device_type,
+         device_info=device_info,
+     )


def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP=1, EP=1, name="",
                 verbose=True):
    from itertools import chain
    from bisect import bisect_left
+
    batches = list(chain(*[range(*r) for r in batch_ranges]))
    # collect performance data
    perfs = []
@@ -198,18 +214,13 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
    y_comp = [max_tflops] * len(x_comp)
    ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.1f} TB/s)", color="blue")
    ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)", color="orange")
-     x_bw, x_comp = xs[:knee], xs[knee:]
-     x_bw = [x_bw[0], x_comp[0]]
-     y_bw = [opints[0] * max_tbps, max_tflops]
-     y_comp = [max_tflops] * len(x_comp)
-     ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.1f} TB/s)")
-     ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)")
    # plot data
    ax.scatter(xs, perf, marker="+")
    ax.legend(frameon=False, loc="lower right")
    ax.grid(True, which="both", ls=":", lw=0.5)
    fig.tight_layout()
-     fpath = Path(f"logs/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/roofline.png")
+     rank, _ = triton_dist.setup()
+     fpath = get_bench_path(name, rank, x_dtype, w_dtype, TP, EP) / "roofline.png"
    plt.savefig(fpath)


@@ -219,7 +230,34 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
    batch_ranges_moe = [(128, 512, 32), (512, 32000, 128)]
    dense_dtypes = ["fp8", "fp8"]
    quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
-     roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
-     roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
-     roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
-     roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")
+     rank, world_size = triton_dist.setup()
+     if world_size > 1:
+         # Running all workloads at once may cause OOM on some GPUs such as H100 80GB.
+         # Thus we request users to run each workload separately.
+         # For example, all eligible combinations of options are listed below when four GPUs are used:
+         # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 2 --ep 2 --name llama4-maverick
+         # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 1 --ep 4 --name llama4-maverick
+         # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name llama4-maverick
+         # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name dense
+         # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 2 --ep 2 --name llama4-maverick --quantized
+         # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 1 --ep 4 --name llama4-maverick --quantized
+         # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name llama4-maverick --quantized
+         # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name dense --quantized
+         parser = argparse.ArgumentParser()
+         parser.add_argument("--tp", type=int, default=1)
+         parser.add_argument("--ep", type=int, default=1)
+         parser.add_argument("--name", type=str, choices=["dense", "llama4-maverick"])
+         parser.add_argument("--quantized", action="store_true", default=False)
+         args = parser.parse_args()
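+         # --quantized selects quantized_dtypes (mx4 weights); the default benchmarks dense_dtypes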
+         dtypes = quantized_dtypes if args.quantized else dense_dtypes
+         if args.name == "dense":
+             assert args.ep == 1, "EP must be 1 for dense"
+             roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dtypes, TP=args.tp, EP=args.ep, name="dense")
+         else:
+             roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dtypes, TP=args.tp, EP=args.ep, name="llama4-maverick")
+         triton_dist.cleanup()
+     else:
+         roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
+         roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
+         roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
+         roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")