2020
2121sys .path .insert (0 , os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../.." )))
2222
23- COEFFS_CFG = dict (tmax = 20 , rms_eps = 1e-6 , pre_eps = 0.0 , sinkhorn_eps = 1e-6 , post_mult = 2.0 )
2423B = 4
2524
2625
@@ -31,9 +30,15 @@ def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
3130 HC = input .extra_benchmark_config ["HC" ]
3231 C = input .extra_benchmark_config ["C" ]
3332 sub_kernel = input .extra_benchmark_config ["sub_kernel" ]
33+ tmax = input .extra_benchmark_config ["tmax" ]
34+ rms_eps = input .extra_benchmark_config ["rms_eps" ]
35+ pre_eps = input .extra_benchmark_config ["pre_eps" ]
36+ sinkhorn_eps = input .extra_benchmark_config ["sinkhorn_eps" ]
37+ post_mult = input .extra_benchmark_config ["post_mult" ]
3438 provider = input .kernel_provider
3539 mode = input .kernel_operation_mode
3640
41+ coeffs_cfg = dict (tmax = tmax , rms_eps = rms_eps , pre_eps = pre_eps , sinkhorn_eps = sinkhorn_eps , post_mult = post_mult )
3742 need_grad = mode in ("backward" , "full" )
3843
3944 x = torch .randn (B , T , HC , C , device = device , dtype = torch .bfloat16 , requires_grad = need_grad )
@@ -50,8 +55,8 @@ def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
5055
5156 def fwd ():
5257 if provider == "liger" :
53- return liger_mhc_coeffs (x , phi , b_param , alpha_pre , alpha_post , alpha_res , ** COEFFS_CFG )
54- return mhc_coeffs_ref (x , phi , b_param , alpha_pre , alpha_post , alpha_res , ** COEFFS_CFG )
58+ return liger_mhc_coeffs (x , phi , b_param , alpha_pre , alpha_post , alpha_res , ** coeffs_cfg )
59+ return mhc_coeffs_ref (x , phi , b_param , alpha_pre , alpha_post , alpha_res , ** coeffs_cfg )
5560
5661 def fwd_loss ():
5762 h_pre , h_post , h_res = fwd ()
@@ -66,7 +71,7 @@ def fwd_loss():
6671 alpha_pre .detach (),
6772 alpha_post .detach (),
6873 alpha_res .detach (),
69- ** COEFFS_CFG ,
74+ ** coeffs_cfg ,
7075 )
7176 h_pre_c .requires_grad_ (need_grad )
7277 grad_to_none = [x , h_pre_c ] if need_grad else None
@@ -88,7 +93,7 @@ def fwd_loss():
8893 alpha_pre .detach (),
8994 alpha_post .detach (),
9095 alpha_res .detach (),
91- ** COEFFS_CFG ,
96+ ** coeffs_cfg ,
9297 )
9398 h_post_c .requires_grad_ (need_grad )
9499 h_res_c .requires_grad_ (need_grad )
@@ -133,8 +138,15 @@ def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
133138 HC = input .extra_benchmark_config ["HC" ]
134139 C = input .extra_benchmark_config ["C" ]
135140 sub_kernel = input .extra_benchmark_config ["sub_kernel" ]
141+ tmax = input .extra_benchmark_config ["tmax" ]
142+ rms_eps = input .extra_benchmark_config ["rms_eps" ]
143+ pre_eps = input .extra_benchmark_config ["pre_eps" ]
144+ sinkhorn_eps = input .extra_benchmark_config ["sinkhorn_eps" ]
145+ post_mult = input .extra_benchmark_config ["post_mult" ]
136146 provider = input .kernel_provider
137147
148+ coeffs_cfg = dict (tmax = tmax , rms_eps = rms_eps , pre_eps = pre_eps , sinkhorn_eps = sinkhorn_eps , post_mult = post_mult )
149+
138150 x = torch .randn (B , T , HC , C , device = device , dtype = torch .bfloat16 , requires_grad = True )
139151 K , M = HC * C , HC * HC + 2 * HC
140152 phi = (torch .randn (K , M , device = device , dtype = torch .bfloat16 ) * 0.02 ).requires_grad_ (True )
@@ -147,9 +159,9 @@ def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
147159
148160 def full ():
149161 if provider == "liger" :
150- hp , hpo , hr = liger_mhc_coeffs (x , phi , b_param , alpha_pre , alpha_post , alpha_res , ** COEFFS_CFG )
162+ hp , hpo , hr = liger_mhc_coeffs (x , phi , b_param , alpha_pre , alpha_post , alpha_res , ** coeffs_cfg )
151163 else :
152- hp , hpo , hr = mhc_coeffs_ref (x , phi , b_param , alpha_pre , alpha_post , alpha_res , ** COEFFS_CFG )
164+ hp , hpo , hr = mhc_coeffs_ref (x , phi , b_param , alpha_pre , alpha_post , alpha_res , ** coeffs_cfg )
153165 (hp .square ().mean () + hpo .square ().mean () + hr .square ().mean ()).backward ()
154166
155167 elif sub_kernel == "pre" :
@@ -161,7 +173,7 @@ def full():
161173 alpha_pre .detach (),
162174 alpha_post .detach (),
163175 alpha_res .detach (),
164- ** COEFFS_CFG ,
176+ ** coeffs_cfg ,
165177 )
166178 h_pre_c .requires_grad_ (True )
167179
@@ -181,7 +193,7 @@ def full():
181193 alpha_pre .detach (),
182194 alpha_post .detach (),
183195 alpha_res .detach (),
184- ** COEFFS_CFG ,
196+ ** coeffs_cfg ,
185197 )
186198 h_post_c .requires_grad_ (True )
187199 h_res_c .requires_grad_ (True )
@@ -215,6 +227,10 @@ def full():
215227 "HC" : 4 ,
216228 "C" : 4096 ,
217229 "tmax" : 20 ,
230+ "rms_eps" : 1e-6 ,
231+ "pre_eps" : 0.0 ,
232+ "sinkhorn_eps" : 1e-6 ,
233+ "post_mult" : 2.0 ,
218234 "sub_kernel" : sub_kernel ,
219235 }
220236 ],
0 commit comments