77from test .utils import set_seed
88from test .utils import supports_bfloat16
99
10- from liger_kernel .transformers .functional import liger_mhc_coeffs as fn_mhc_coeffs
11- from liger_kernel .transformers .functional import liger_mhc_coeffs as op_mhc_coeffs
12- from liger_kernel .transformers .functional import liger_mhc_post_res as fn_mhc_post_res
13- from liger_kernel .transformers .functional import liger_mhc_post_res as op_mhc_post_res
14- from liger_kernel .transformers .functional import liger_mhc_pre as fn_mhc_pre
15- from liger_kernel .transformers .functional import liger_mhc_pre as op_mhc_pre
10+ from liger_kernel .transformers .functional import liger_mhc_coeffs
11+ from liger_kernel .transformers .functional import liger_mhc_post_res
12+ from liger_kernel .transformers .functional import liger_mhc_pre
1613from liger_kernel .transformers .mhc import LigerMHC
1714
1815device = infer_device ()
@@ -142,7 +139,7 @@ def test_mhc_coeffs_forward_backward(B, T, HC, C, phi_dtype, dtype, pre_post_tol
142139
143140 cfg = dict (tmax = 8 , rms_eps = 1e-6 , pre_eps = 1e-4 , sinkhorn_eps = 1e-6 , post_mult = 2.0 )
144141
145- h_pre , h_post , h_res = op_mhc_coeffs (x , phi , b , alpha_pre , alpha_post , alpha_res , ** cfg )
142+ h_pre , h_post , h_res = liger_mhc_coeffs (x , phi , b , alpha_pre , alpha_post , alpha_res , ** cfg )
146143
147144 loss = h_pre .square ().mean () + h_post .square ().mean () + h_res .square ().mean ()
148145 loss .backward ()
@@ -201,7 +198,7 @@ def test_mhc_coeffs_allow_fp32(B, T, HC, C, dtype, pre_post_tol, res_tol, grad_t
201198
202199 cfg = dict (tmax = 8 , rms_eps = 1e-6 , pre_eps = 1e-4 , sinkhorn_eps = 1e-6 , post_mult = 2.0 )
203200
204- h_pre , h_post , h_res = op_mhc_coeffs (x , phi , b , alpha_pre , alpha_post , alpha_res , allow_fp32 = True , ** cfg )
201+ h_pre , h_post , h_res = liger_mhc_coeffs (x , phi , b , alpha_pre , alpha_post , alpha_res , allow_fp32 = True , ** cfg )
205202
206203 loss = h_pre .square ().mean () + h_post .square ().mean () + h_res .square ().mean ()
207204 loss .backward ()
@@ -257,7 +254,7 @@ def test_mhc_coeffs_disallow_fp32():
257254 alpha_res = torch .tensor (1.0 , device = device , dtype = torch .float32 )
258255
259256 with pytest .raises (AssertionError ):
260- _ = op_mhc_coeffs (x , phi , b , alpha_pre , alpha_post , alpha_res )
257+ _ = liger_mhc_coeffs (x , phi , b , alpha_pre , alpha_post , alpha_res )
261258
262259
263260@pytest .mark .skipif (device != "cuda" , reason = "CUDA required" )
@@ -284,7 +281,7 @@ def test_mhc_coeffs_backward_allows_unused_outputs(B, T, HC, C, use_pre, use_pos
284281
285282 cfg = dict (tmax = 4 , rms_eps = 1e-6 , pre_eps = 1e-4 , sinkhorn_eps = 1e-6 , post_mult = 2.0 )
286283
287- h_pre , h_post , h_res = op_mhc_coeffs (x , phi , b , alpha_pre , alpha_post , alpha_res , ** cfg )
284+ h_pre , h_post , h_res = liger_mhc_coeffs (x , phi , b , alpha_pre , alpha_post , alpha_res , ** cfg )
288285
289286 loss = torch .zeros ((), device = device )
290287 if use_pre :
@@ -310,9 +307,9 @@ def test_mhc_pre_and_post_res_match_reference(B, T, HC, C, dtype, pre_post_tol,
310307 h_post = torch .rand (B , T , HC , device = device , dtype = torch .float32 , requires_grad = True )
311308 h_res = torch .rand (B , T , HC , HC , device = device , dtype = torch .float32 , requires_grad = True )
312309
313- x_in = op_mhc_pre (x , h_pre )
310+ x_in = liger_mhc_pre (x , h_pre )
314311 f_out = torch .randn (B , T , C , device = device , dtype = dtype , requires_grad = True )
315- x_out = op_mhc_post_res (x , f_out , h_post , h_res )
312+ x_out = liger_mhc_post_res (x , f_out , h_post , h_res )
316313
317314 x_in_ref = (x .float () * h_pre .unsqueeze (- 1 )).sum (dim = - 2 )
318315 x_out_ref = torch .einsum ("...oi,...ic->...oc" , h_res , x .float ()) + h_post .unsqueeze (- 1 ) * f_out .float ().unsqueeze (
@@ -340,7 +337,7 @@ def test_liger_mhc_functional(B, T, HC, C, dtype, pre_post_tol, res_tol, grad_to
340337
341338 cfg = dict (tmax = 4 , rms_eps = 1e-6 , pre_eps = 1e-4 , sinkhorn_eps = 1e-6 , post_mult = 2.0 )
342339
343- h_pre , h_post , h_res = fn_mhc_coeffs (x , phi , b , alpha_pre , alpha_post , alpha_res , ** cfg )
340+ h_pre , h_post , h_res = liger_mhc_coeffs (x , phi , b , alpha_pre , alpha_post , alpha_res , ** cfg )
344341 rh_pre , rh_post , rh_res = mhc_coeffs_ref (x , phi , b , alpha_pre , alpha_post , alpha_res , ** cfg )
345342
346343 assert_verbose_allclose (h_pre .float (), rh_pre .float (), rtol = pre_post_tol , atol = pre_post_tol , extra_info = "[h_pre]" )
@@ -381,8 +378,8 @@ def test_liger_mhc_functional(B, T, HC, C, dtype, pre_post_tol, res_tol, grad_to
381378 h_res3 = h_res .detach ().clone ().requires_grad_ (True )
382379 f_out = torch .randn (B , T , C , device = device , dtype = dtype , requires_grad = True )
383380
384- x_in = fn_mhc_pre (x3 , h_pre3 )
385- x_out = fn_mhc_post_res (x3 , f_out , h_post3 , h_res3 )
381+ x_in = liger_mhc_pre (x3 , h_pre3 )
382+ x_out = liger_mhc_post_res (x3 , f_out , h_post3 , h_res3 )
386383
387384 x_in_ref = (x3 .float () * h_pre3 .unsqueeze (- 1 )).sum (dim = - 2 )
388385 x_out_ref = torch .einsum ("...oi,...ic->...oc" , h_res3 , x3 .float ()) + h_post3 .unsqueeze (
0 commit comments