Skip to content

Commit 94d5da7

Browse files
committed
Refactor MHC tests to simplify imports
1 parent aca1fb8 commit 94d5da7

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

test/transformers/test_mhc.py

Lines changed: 12 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,9 @@
77
from test.utils import set_seed
88
from test.utils import supports_bfloat16
99

10-
from liger_kernel.transformers.functional import liger_mhc_coeffs as fn_mhc_coeffs
11-
from liger_kernel.transformers.functional import liger_mhc_coeffs as op_mhc_coeffs
12-
from liger_kernel.transformers.functional import liger_mhc_post_res as fn_mhc_post_res
13-
from liger_kernel.transformers.functional import liger_mhc_post_res as op_mhc_post_res
14-
from liger_kernel.transformers.functional import liger_mhc_pre as fn_mhc_pre
15-
from liger_kernel.transformers.functional import liger_mhc_pre as op_mhc_pre
10+
from liger_kernel.transformers.functional import liger_mhc_coeffs
11+
from liger_kernel.transformers.functional import liger_mhc_post_res
12+
from liger_kernel.transformers.functional import liger_mhc_pre
1613
from liger_kernel.transformers.mhc import LigerMHC
1714

1815
device = infer_device()
@@ -142,7 +139,7 @@ def test_mhc_coeffs_forward_backward(B, T, HC, C, phi_dtype, dtype, pre_post_tol
142139

143140
cfg = dict(tmax=8, rms_eps=1e-6, pre_eps=1e-4, sinkhorn_eps=1e-6, post_mult=2.0)
144141

145-
h_pre, h_post, h_res = op_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res, **cfg)
142+
h_pre, h_post, h_res = liger_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res, **cfg)
146143

147144
loss = h_pre.square().mean() + h_post.square().mean() + h_res.square().mean()
148145
loss.backward()
@@ -201,7 +198,7 @@ def test_mhc_coeffs_allow_fp32(B, T, HC, C, dtype, pre_post_tol, res_tol, grad_t
201198

202199
cfg = dict(tmax=8, rms_eps=1e-6, pre_eps=1e-4, sinkhorn_eps=1e-6, post_mult=2.0)
203200

204-
h_pre, h_post, h_res = op_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res, allow_fp32=True, **cfg)
201+
h_pre, h_post, h_res = liger_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res, allow_fp32=True, **cfg)
205202

206203
loss = h_pre.square().mean() + h_post.square().mean() + h_res.square().mean()
207204
loss.backward()
@@ -257,7 +254,7 @@ def test_mhc_coeffs_disallow_fp32():
257254
alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32)
258255

259256
with pytest.raises(AssertionError):
260-
_ = op_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res)
257+
_ = liger_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res)
261258

262259

263260
@pytest.mark.skipif(device != "cuda", reason="CUDA required")
@@ -284,7 +281,7 @@ def test_mhc_coeffs_backward_allows_unused_outputs(B, T, HC, C, use_pre, use_pos
284281

285282
cfg = dict(tmax=4, rms_eps=1e-6, pre_eps=1e-4, sinkhorn_eps=1e-6, post_mult=2.0)
286283

287-
h_pre, h_post, h_res = op_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res, **cfg)
284+
h_pre, h_post, h_res = liger_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res, **cfg)
288285

289286
loss = torch.zeros((), device=device)
290287
if use_pre:
@@ -310,9 +307,9 @@ def test_mhc_pre_and_post_res_match_reference(B, T, HC, C, dtype, pre_post_tol,
310307
h_post = torch.rand(B, T, HC, device=device, dtype=torch.float32, requires_grad=True)
311308
h_res = torch.rand(B, T, HC, HC, device=device, dtype=torch.float32, requires_grad=True)
312309

313-
x_in = op_mhc_pre(x, h_pre)
310+
x_in = liger_mhc_pre(x, h_pre)
314311
f_out = torch.randn(B, T, C, device=device, dtype=dtype, requires_grad=True)
315-
x_out = op_mhc_post_res(x, f_out, h_post, h_res)
312+
x_out = liger_mhc_post_res(x, f_out, h_post, h_res)
316313

317314
x_in_ref = (x.float() * h_pre.unsqueeze(-1)).sum(dim=-2)
318315
x_out_ref = torch.einsum("...oi,...ic->...oc", h_res, x.float()) + h_post.unsqueeze(-1) * f_out.float().unsqueeze(
@@ -340,7 +337,7 @@ def test_liger_mhc_functional(B, T, HC, C, dtype, pre_post_tol, res_tol, grad_to
340337

341338
cfg = dict(tmax=4, rms_eps=1e-6, pre_eps=1e-4, sinkhorn_eps=1e-6, post_mult=2.0)
342339

343-
h_pre, h_post, h_res = fn_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res, **cfg)
340+
h_pre, h_post, h_res = liger_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res, **cfg)
344341
rh_pre, rh_post, rh_res = mhc_coeffs_ref(x, phi, b, alpha_pre, alpha_post, alpha_res, **cfg)
345342

346343
assert_verbose_allclose(h_pre.float(), rh_pre.float(), rtol=pre_post_tol, atol=pre_post_tol, extra_info="[h_pre]")
@@ -381,8 +378,8 @@ def test_liger_mhc_functional(B, T, HC, C, dtype, pre_post_tol, res_tol, grad_to
381378
h_res3 = h_res.detach().clone().requires_grad_(True)
382379
f_out = torch.randn(B, T, C, device=device, dtype=dtype, requires_grad=True)
383380

384-
x_in = fn_mhc_pre(x3, h_pre3)
385-
x_out = fn_mhc_post_res(x3, f_out, h_post3, h_res3)
381+
x_in = liger_mhc_pre(x3, h_pre3)
382+
x_out = liger_mhc_post_res(x3, f_out, h_post3, h_res3)
386383

387384
x_in_ref = (x3.float() * h_pre3.unsqueeze(-1)).sum(dim=-2)
388385
x_out_ref = torch.einsum("...oi,...ic->...oc", h_res3, x3.float()) + h_post3.unsqueeze(

0 commit comments

Comments (0)