from pathlib import Path
import json
+ import triton
import triton.profiler as proton
import torch
import triton_bench.swiglu
from triton_bench.mxfp import downcast_to_mxfp
from triton_bench.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx
from triton_bench.numerics import InFlexData
- from triton_bench.routing import routing_torch, simulate_expert_sharded_routing
+ from triton_bench.routing import routing, simulate_expert_sharded_routing
from triton_bench.meta import cuda_capability_geq

- if torch.cuda.is_available():
+
+ def is_hip_cdna4():
+     target = triton.runtime.driver.active.get_current_target()
+     return target.backend == 'hip' and target.arch == 'gfx950'
+
+
+ if torch.cuda.is_available() and not is_hip_cdna4():
    from triton._C.libtriton import nvidia
    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    cublas = nvidia.cublas.CublasLt(cublas_workspace)


def _query_gpu_specs():
+     if is_hip_cdna4():
+         # no spec data yet.
+         return None
    import subprocess
    cmd = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader", "-i=0"]
    output = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode().strip()
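The new guard asks Triton's runtime driver which target it is compiling for. Below is a minimal sketch of the same check and how a caller might branch on it, assuming get_current_target() returns an object with backend and arch fields; describe_device is a hypothetical helper, not part of this commit.

import triton

def describe_device():
    # Query the active compilation target from Triton's runtime driver.
    target = triton.runtime.driver.active.get_current_target()
    if target.backend == "hip" and target.arch == "gfx950":
        # CDNA4 has no entry in the spec table yet, so _query_gpu_specs() returns None.
        return "hip / gfx950 (CDNA4, SPECS unavailable)"
    return f"{target.backend} / {target.arch}"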
@@ -86,17 +96,19 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
    for i in range(100):
        x = torch.randn((batch, dim1), device=dev)
        x = x.to(wg.dtype if n_expts_tot > 1 else x_dtype)
-         # TODO: activate proton here when fast routing is done
+         proton.activate()
        if n_expts_tot > 1:
            logits = matmul_ogs(x, wg, bg, precision_config=pcg)
-             rdata, gather_indx, scatter_indx = routing_torch(logits, n_expts_act)
+             rdata, gather_indx, scatter_indx = routing(logits, n_expts_act)
            if EP > 1:
+                 proton.deactivate()
+                 # TODO: activate proton here when fast expert parallelism simulation is done
                m = logits.shape[0] * EP
                _, rdata, gather_indx, scatter_indx = simulate_expert_sharded_routing(m, rdata, EP, device=dev)
+                 proton.activate()
            x = x.to(x_dtype)
        else:
            rdata, gather_indx, scatter_indx = None, None, None
-         proton.activate()
        # c0 = torch.empty((x.shape[0], w1.shape[-1]), device=dev, dtype=x.dtype)
        # c1 = torch.empty((x.shape[0], w2.shape[-1]), device=dev, dtype=x.dtype)
        # cublas.matmul(x, w1.squeeze(0), c0)
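The activate/deactivate pair controls what proton records, which is how the simulated expert-parallel routing is kept out of the measurement. A minimal, self-contained sketch of that bracketing pattern, assuming proton's session API (start/activate/deactivate/finalize), a CUDA device, and a placeholder matmul workload:

import torch
import triton.profiler as proton

proton.start("bracketing-demo")  # open a profiling session
proton.activate()                # everything from here on is recorded
a = torch.randn(2048, 2048, device="cuda")
b = a @ a                        # measured GPU work
proton.deactivate()              # pause recording for host-side simulation / bookkeeping
checksum = float(b.sum())        # excluded from the profile
proton.activate()                # resume recording
c = b @ b                        # measured again
proton.finalize()                # flush the recorded profile to disk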
@@ -119,8 +131,10 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
    # TODO: proton should really be recording that in the json instead of
    # relying on the user to aggregate
    tot_time = sum(x["metrics"].get("time (ns)", 0) for x in data[0]["children"])
-     min_time_flops = sum([tot_flops[w] / SPECS[f"MAX_TFLOPS{w}"] for w in [8, 16]]) * 1e-3
-     min_time_bytes = tot_bytes / SPECS["MAX_TBPS"] * 1e-3
+     min_time_flops = min_time_bytes = 0
+     if SPECS is not None:
+         min_time_flops = sum([tot_flops[w] / SPECS[f"MAX_TFLOPS{w}"] for w in [8, 16]]) * 1e-3
+         min_time_bytes = tot_bytes / SPECS["MAX_TBPS"] * 1e-3
    min_time = max(min_time_flops, min_time_bytes)
    util = min_time / tot_time
    tflops = sum([tot_flops[w] for w in [8, 16]]) / tot_time * 1e-3
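The guarded block computes a roofline-style lower bound on runtime (the slower of the compute-bound and memory-bound estimates) and now degrades gracefully when no spec table exists for the device. A self-contained restatement of that logic, using the same field names as above; roofline_util is an illustrative name, not part of the commit.

def roofline_util(tot_time, tot_flops, tot_bytes, specs):
    # Lower bound from per-precision peak TFLOPS and from peak memory bandwidth (TB/s);
    # with specs unavailable (None) both bounds collapse to 0, and so does the utilization.
    min_time_flops = min_time_bytes = 0
    if specs is not None:
        min_time_flops = sum(tot_flops[w] / specs[f"MAX_TFLOPS{w}"] for w in [8, 16]) * 1e-3
        min_time_bytes = tot_bytes / specs["MAX_TBPS"] * 1e-3
    min_time = max(min_time_flops, min_time_bytes)
    return min_time / tot_time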
@@ -130,9 +144,9 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,


if __name__ == "__main__":
-     has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10
+     has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or is_hip_cdna4()
    qxdtype = "fp8" if has_native_mx4 else "bf16"
    print(bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense"))
    print(bench_mlp(8192, 8192, 8192, 1, 1, qxdtype, "mx4", TP=1, EP=1, name="dense"))
-     print(bench_mlp(1024, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=2, name="llama4"))
-     print(bench_mlp(1024, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=2, name="llama4"))
+     print(bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4"))
+     print(bench_mlp(2048, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=1, name="llama4"))
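The __main__ block now treats CDNA4 like Blackwell when choosing the activation dtype for the mx4-weight benchmarks: with native MX4 support the quantized runs use fp8 activations, otherwise they fall back to bf16. A small sketch of that selection, assuming the is_hip_cdna4() helper from the top of the file; pick_quantized_dtype is illustrative, not part of the commit.

import torch

def pick_quantized_dtype():
    # Native MX4 support: NVIDIA compute capability >= 10 (Blackwell) or AMD CDNA4.
    has_native_mx4 = (torch.cuda.is_available()
                      and torch.cuda.get_device_capability(0)[0] >= 10) or is_hip_cdna4()
    return "fp8" if has_native_mx4 else "bf16"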