@@ -78,17 +78,37 @@ def convert_dtype(dtype):
 
 
 def matmul_launch_metadata(grid, kernel, args):
+    from ..proton_opts import launch_metadata_allow_sync
+
     ret = dict()
     M, N, K = args["M"], args["N"], args["K"]
     Y, X, W = [t.base if isinstance(t, TensorDescriptor) else t for t in [args["Y"], args["X"], args["W"]]]
+    tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
     hist = args["ExptHist"]
     if hist is not None:
-        n_tokens = float(hist.sum())
-        n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum()
+        # If annotation is given, use that to generate name for profiling.
+        if tokens_per_expt is not None:
+            n_rows = f"{tokens_per_expt}*"
+        elif launch_metadata_allow_sync():
+            n_rows = int(hist.float().mean())
+        else:
+            n_rows = "unknown"
+
+        if launch_metadata_allow_sync():
+            n_tokens = float(hist.sum())
+            n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum()
+        elif tokens_per_expt is not None:
+            n_tokens = tokens_per_expt * args["N_EXPTS_TOT"]
+            # This may not be totally correct (e.g., we might not be using all experts)
+            # but it's better than nothing.
+            n_w_bytes = W.numel() * W.element_size()
+        else:
+            n_tokens = None
+            n_w_bytes = 0
 
         # If annotation is given, use that to generate name for profiling.
         tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
-        n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else int(hist.float().mean())
+        n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows
     else:
         n_tokens = None
         n_w_bytes = W.numel() * W.element_size()
@@ -101,6 +121,10 @@ def matmul_launch_metadata(grid, kernel, args):
     ep_subtile = args["EPILOGUE_SUBTILE"]
     if ep_subtile is not None and ep_subtile > 1:
         ret["name"] += f" ep/{ep_subtile}"
+
+    if hist is not None and n_tokens is None:
+        return ret  # Don't fill metadata because we can't compute them properly.
+
     fM = M if M is not None else n_tokens
     fK = K if K is not None else n_tokens
     ret[f"flops{nbits}"] = 2.0 * fM * N * fK
@@ -115,7 +139,7 @@ def matmul_launch_metadata(grid, kernel, args):
     assert n_tokens is not None
     n_expts_act = args["N_EXPTS_ACT"]
 
-    if gindx is not None:
+    if (gindx is not None) and launch_metadata_allow_sync():
         # recreate inverse GatherIndx.
         dst = torch.full_like(gindx, -1)
         idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
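
For reference, below is a small standalone sketch (not part of the diff) of the fallback order the new code follows when it cannot or may not synchronize with the device: the TOKENS_PER_EXPT_FOR_ANNOTATION hint takes precedence for the profile-name row label, a device sync (when launch_metadata_allow_sync() permits it) gives exact token counts, and otherwise the flop/byte metadata is left unfilled. The helper name estimate_launch_stats and the toy histogram are hypothetical illustrations, not part of the PR.

import torch


def estimate_launch_stats(hist, tokens_per_expt, n_expts_tot, allow_sync):
    # Row label for the profile name: the annotation wins, then a synced
    # histogram mean, then "unknown".
    if tokens_per_expt is not None:
        n_rows = f"{tokens_per_expt}*"
    elif allow_sync:
        n_rows = int(hist.float().mean())
    else:
        n_rows = "unknown"

    # Token count: syncing on the histogram is exact; the annotation gives an
    # estimate (which may overcount if not all experts are used); otherwise the
    # metadata cannot be computed and matmul_launch_metadata returns early.
    if allow_sync:
        n_tokens = float(hist.sum())
    elif tokens_per_expt is not None:
        n_tokens = tokens_per_expt * n_expts_tot
    else:
        n_tokens = None
    return n_tokens, n_rows


hist = torch.tensor([4, 0, 2, 2])                               # toy tokens-per-expert histogram
print(estimate_launch_stats(hist, None, hist.numel(), True))    # (8.0, 2)
print(estimate_launch_stats(hist, 8, hist.numel(), False))      # (32, '8*')
print(estimate_launch_stats(hist, None, hist.numel(), False))   # (None, 'unknown')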