@@ -78,17 +78,37 @@ def convert_dtype(dtype):
 
 
 def matmul_launch_metadata(grid, kernel, args):
+    from ..proton_opts import launch_metadata_allow_sync
+
     ret = dict()
     M, N, K = args["M"], args["N"], args["K"]
     Y, X, W = [t.base if isinstance(t, TensorDescriptor) else t for t in [args["Y"], args["X"], args["W"]]]
+    tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
     hist = args["ExptHist"]
     if hist is not None:
-        n_tokens = float(hist.sum())
-        n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum()
+        # If annotation is given, use that to generate name for profiling.
+        if tokens_per_expt is not None:
+            n_rows = f"{tokens_per_expt}*"
+        elif launch_metadata_allow_sync():
+            n_rows = int(hist.float().mean())
+        else:
+            n_rows = "unknown"
+
+        if launch_metadata_allow_sync():
+            n_tokens = float(hist.sum())
+            n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum()
+        elif tokens_per_expt is not None:
+            n_tokens = tokens_per_expt * args["N_EXPTS_TOT"]
+            # This may not be totally correct (e.g., we might not be using all experts)
+            # but it's better than nothing.
+            n_w_bytes = W.numel() * W.element_size()
+        else:
+            n_tokens = None
+            n_w_bytes = 0
 
         # If annotation is given, use that to generate name for profiling.
         tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
-        n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else int(hist.float().mean())
+        n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows
     else:
         n_tokens = None
         n_w_bytes = W.numel() * W.element_size()
@@ -101,6 +121,10 @@ def matmul_launch_metadata(grid, kernel, args):
     ep_subtile = args["EPILOGUE_SUBTILE"]
     if ep_subtile is not None and ep_subtile > 1:
         ret["name"] += f" ep/{ep_subtile}"
+
+    if hist is not None and n_tokens is None:
+        return ret  # Don't fill metadata because we can't compute them properly.
+
     fM = M if M is not None else n_tokens
     fK = K if K is not None else n_tokens
     ret[f"flops{nbits}"] = 2.0 * fM * N * fK
@@ -115,7 +139,7 @@ def matmul_launch_metadata(grid, kernel, args):
     assert n_tokens is not None
     n_expts_act = args["N_EXPTS_ACT"]
 
-    if gindx is not None:
+    if (gindx is not None) and launch_metadata_allow_sync():
         # recreate inverse GatherIndx.
         dst = torch.full_like(gindx, -1)
         idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
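
For context, the patch gates every device-synchronizing computation (hist.sum(), hist.float().mean(), the GatherIndx inversion) behind launch_metadata_allow_sync() imported from ..proton_opts, and falls back to the TOKENS_PER_EXPT_FOR_ANNOTATION hint or an early return when syncing is not allowed. A minimal sketch of what such a module-level toggle could look like is below; the environment-variable name and the setter are assumptions for illustration, not the actual proton_opts API.

# Hypothetical sketch of a proton_opts-style toggle. Only the name
# launch_metadata_allow_sync() comes from the diff above; the env-var
# name and set_launch_metadata_allow_sync() are assumptions.
import os

_launch_metadata_allow_sync: bool | None = None


def launch_metadata_allow_sync() -> bool:
    # When True, launch-metadata hooks may synchronize with the device
    # (e.g. call .sum()/.mean() on GPU tensors) to report exact counts.
    global _launch_metadata_allow_sync
    if _launch_metadata_allow_sync is None:
        _launch_metadata_allow_sync = os.getenv("PROTON_LAUNCH_METADATA_ALLOW_SYNC", "1") == "1"
    return _launch_metadata_allow_sync


def set_launch_metadata_allow_sync(value: bool) -> None:
    # Explicit override, e.g. to keep profiling hooks from blocking on the GPU.
    global _launch_metadata_allow_sync
    _launch_metadata_allow_sync = value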