
Commit 5c9eca9

mhc: code quality cleanup across ops, tests, and benchmarks
- Remove no-op mask=True from Sinkhorn backward kernels
- Drop unused rms_eps/pre_eps from ctx.meta in coeffs backward
- Remove redundant .contiguous() calls inside @ensure_contiguous methods
- Simplify grad_x reshape to use x_shape directly
- Simplify device detection in LigerMHC to try/except pattern
- Replace torch.allclose with assert_verbose_allclose in tests
- Standardize seed to set_seed(42) across all tests
- Merge test_mhc_coeffs_allow_fp32 into test_mhc_coeffs_forward_backward
- Add backward coverage to test_mhc_pre_and_post_res_match_reference
- Widen bf16 tolerance for layer.weight.grad and phi.grad in module test
- Move hardcoded B into extra_benchmark_configs (benchmark_mhc.py)
- Rename MiniMHCLM to BenchMiniMHCLM in benchmark_mhc_lm.py
- Split _build_models into single-provider _build_model
1 parent af0e661 commit 5c9eca9
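The test-side items above touch a file whose diff is not shown on this page; they amount to pinning the RNG with the shared `set_seed` helper and swapping bare `torch.allclose` assertions for `assert_verbose_allclose`, which reports the mismatching elements. A minimal sketch of that pattern, assuming the helpers in `test/utils.py` keep their usual `set_seed(seed)` and `assert_verbose_allclose(t1, t2, rtol=..., atol=...)` shapes; the tensors, tolerances, and test name below are illustrative, not taken from the commit:

```python
import torch

from test.utils import assert_verbose_allclose, set_seed

set_seed(42)  # standardized seed per the commit message


def test_mhc_coeffs_forward_backward_sketch():
    # Illustrative tensors only; the real comparisons live in test/transformers/test_mhc.py.
    out_liger = torch.randn(4, 16)
    out_ref = out_liger.clone()
    # Unlike `assert torch.allclose(...)`, this helper prints the offending
    # elements and their indices when the check fails.
    assert_verbose_allclose(out_liger, out_ref, rtol=1e-5, atol=1e-5)
```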

5 files changed (+90, -146 lines)

benchmark/scripts/benchmark_mhc.py

Lines changed: 3 additions & 2 deletions
@@ -20,13 +20,12 @@
 
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
 
-B = 4
-
 
 def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     from test.transformers.test_mhc import mhc_coeffs_ref
 
     T = input.x
+    B = input.extra_benchmark_config["B"]
     HC = input.extra_benchmark_config["HC"]
     C = input.extra_benchmark_config["C"]
     sub_kernel = input.extra_benchmark_config["sub_kernel"]
@@ -135,6 +134,7 @@ def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     from test.transformers.test_mhc import mhc_coeffs_ref
 
     T = input.x
+    B = input.extra_benchmark_config["B"]
     HC = input.extra_benchmark_config["HC"]
     C = input.extra_benchmark_config["C"]
     sub_kernel = input.extra_benchmark_config["sub_kernel"]
@@ -224,6 +224,7 @@ def full():
         "kernel_providers": ["liger", "torch"],
         "extra_benchmark_configs": [
             {
+                "B": 4,
                 "HC": 4,
                 "C": 4096,
                 "tmax": 20,
benchmark/scripts/benchmark_mhc_lm.py

Lines changed: 13 additions & 31 deletions
@@ -226,7 +226,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-class MiniMHCLM(nn.Module):
+class BenchMiniMHCLM(nn.Module):
     def __init__(
         self,
         mhc_cls: type[nn.Module],
@@ -274,7 +274,8 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.lm_head(x)
 
 
-def _build_models(
+def _build_model(
+    provider: str,
     *,
     hidden_size: int,
     hc: int,
@@ -285,8 +286,9 @@ def _build_models(
     tmax: int,
     dtype: torch.dtype,
 ):
-    liger_model = MiniMHCLM(
-        LigerMHC,
+    mhc_cls = LigerMHC if provider == "liger" else TorchMHC
+    return BenchMiniMHCLM(
+        mhc_cls,
         vocab_size=vocab_size,
         hidden_size=hidden_size,
         hc=hc,
@@ -297,20 +299,6 @@ def _build_models(
         dtype=dtype,
         device=device,
     )
-    torch_model = MiniMHCLM(
-        TorchMHC,
-        vocab_size=vocab_size,
-        hidden_size=hidden_size,
-        hc=hc,
-        num_layers=num_layers,
-        num_heads=num_heads,
-        intermediate_mult=intermediate_mult,
-        tmax=tmax,
-        dtype=dtype,
-        device=device,
-    )
-    torch_model.load_state_dict(liger_model.state_dict())
-    return liger_model, torch_model
 
 
 def bench_speed_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
@@ -331,7 +319,8 @@ def bench_speed_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     if hidden_size % num_heads != 0:
         raise ValueError("hidden_size must be divisible by num_heads")
 
-    liger_model, torch_model = _build_models(
+    model = _build_model(
+        provider,
         hidden_size=hidden_size,
         hc=hc,
         num_layers=num_layers,
@@ -345,16 +334,12 @@ def bench_speed_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     input_ids = torch.randint(0, vocab_size, (bsz, seq_len), device=device)
 
     def fwd():
-        if provider == "liger":
-            return liger_model(input_ids)
-        if provider == "torch":
-            return torch_model(input_ids)
-        raise ValueError(f"Unknown provider: {provider}")
+        return model(input_ids)
 
     def fwd_loss():
         return fwd().float().mean()
 
-    grad_to_none = list(liger_model.parameters()) if provider == "liger" else list(torch_model.parameters())
+    grad_to_none = list(model.parameters())
 
     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, grad_to_none=grad_to_none, rep=100)
@@ -400,7 +385,8 @@ def bench_memory_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     if hidden_size % num_heads != 0:
         raise ValueError("hidden_size must be divisible by num_heads")
 
-    liger_model, torch_model = _build_models(
+    model = _build_model(
+        provider,
         hidden_size=hidden_size,
         hc=hc,
         num_layers=num_layers,
@@ -414,11 +400,7 @@ def bench_memory_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     input_ids = torch.randint(0, vocab_size, (bsz, seq_len), device=device)
 
     def fwd():
-        if provider == "liger":
-            return liger_model(input_ids)
-        if provider == "torch":
-            return torch_model(input_ids)
-        raise ValueError(f"Unknown provider: {provider}")
+        return model(input_ids)
 
     def full():
         loss = fwd().float().mean()

src/liger_kernel/ops/mhc.py

Lines changed: 11 additions & 17 deletions
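Context for the `mask=True` removals in the hunks below: in Triton, the `mask` argument to `tl.load` exists to guard lanes that would read out of bounds, and a constant all-True mask (paired with `other=0.0`) guards nothing, so dropping it changes no results. A standalone sketch of when a mask is and is not needed (hypothetical kernels, not from this file):

```python
import triton
import triton.language as tl


@triton.jit
def scale_tail_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # The last block may run past n, so a real mask is required here.
    x = tl.load(x_ptr + offs, mask=offs < n, other=0.0)
    tl.store(out_ptr + offs, x * 2.0, mask=offs < n)


@triton.jit
def scale_full_block_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # Every lane is known to be in bounds, so no mask is needed; writing
    # `mask=True, other=0.0` would be a no-op, as in the lines removed below.
    x = tl.load(x_ptr + offs)
    tl.store(out_ptr + offs, x * 2.0)
```

The `.contiguous()` removals in the same file rest on a similar observation: per the commit message, the surrounding autograd methods are already wrapped by `@ensure_contiguous`, so their tensor arguments arrive contiguous and `.view(...)` is safe without a second copy.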
@@ -546,8 +546,6 @@ def _mhc_sinkhorn_bwd_kernel(
     # Start backward from grad_out
     g = tl.load(
         grad_out_ptr + pid * stride_go_n + rows * stride_go_i + cols * stride_go_j,
-        mask=True,
-        other=0.0,
     ).to(tl.float32)
 
     # Reverse iterations (TMAX-1 .. 1), recomputing mat_t, rs_t, cs_t
@@ -641,8 +639,6 @@ def _mhc_sinkhorn_bwd_hist_kernel(
     # Start backward from grad_out
     g = tl.load(
         grad_out_ptr + pid * stride_go_n + rows * stride_go_i + cols * stride_go_j,
-        mask=True,
-        other=0.0,
     ).to(tl.float32)
 
     # Reverse iterations (TMAX-1 .. 1) using stored mats
@@ -1471,8 +1467,6 @@ def forward( # type: ignore[override]
             HC,
             C,
             int(tmax),
-            float(rms_eps),
-            float(pre_eps),
             float(sinkhorn_eps),
             float(post_mult),
             hist is not None,
@@ -1495,7 +1489,7 @@ def backward(
         grad_h_res: torch.Tensor | None,
     ):
         saved = ctx.saved_tensors
-        x_shape, HC, C, tmax, rms_eps, pre_eps, sinkhorn_eps, post_mult, has_hist = ctx.meta
+        x_shape, HC, C, tmax, sinkhorn_eps, post_mult, has_hist = ctx.meta
         if has_hist:
             x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res, hist = saved
         else:
@@ -1511,15 +1505,15 @@ def backward(
 
         # flatten grads (None -> zeros)
         if need_pre:
-            gh_pre = grad_h_pre.contiguous().view(-1, HC).to(torch.float32)
+            gh_pre = grad_h_pre.view(-1, HC).to(torch.float32)
         else:
             gh_pre = torch.zeros((N, HC), device=mix.device, dtype=torch.float32)
         if need_post:
-            gh_post = grad_h_post.contiguous().view(-1, HC).to(torch.float32)
+            gh_post = grad_h_post.view(-1, HC).to(torch.float32)
         else:
             gh_post = torch.zeros((N, HC), device=mix.device, dtype=torch.float32)
         if need_res:
-            gh_res = grad_h_res.contiguous().view(-1, HC, HC).to(torch.float32)
+            gh_res = grad_h_res.view(-1, HC, HC).to(torch.float32)
         else:
             gh_res = torch.zeros((N, HC, HC), device=mix.device, dtype=torch.float32)
 
@@ -1599,7 +1593,7 @@ def backward(
         )
 
         # Reshape to original shape
-        grad_x = grad_x_mat.view(*x_shape[:-2], HC, C)
+        grad_x = grad_x_mat.view(x_shape)
 
         # Return grads for each forward input
         return (
@@ -1624,7 +1618,7 @@ class LigerMHCPreFunction(torch.autograd.Function):
     def forward(ctx: Any, x: torch.Tensor, h_pre: torch.Tensor) -> torch.Tensor:
         x_shape = x.shape
         x_flat, _ = _flatten_tokens(x)
-        h_pre_flat = h_pre.contiguous().view(-1, x_flat.shape[1]).to(torch.float32)
+        h_pre_flat = h_pre.view(-1, x_flat.shape[1]).to(torch.float32)
         out = mhc_pre_fwd(x_flat, h_pre_flat)  # [N,C] fp32
         ctx.save_for_backward(x_flat, h_pre_flat)
         ctx.x_shape = x_shape
@@ -1637,7 +1631,7 @@ def backward(ctx: Any, grad_out: torch.Tensor):
         x_flat, h_pre_flat = ctx.saved_tensors
         x_shape = ctx.x_shape
         N, HC, C = x_flat.shape
-        go = grad_out.contiguous().view(-1, C).to(torch.float32)
+        go = grad_out.view(-1, C).to(torch.float32)
         grad_x, grad_h = mhc_pre_bwd(x_flat, h_pre_flat, go)
         grad_x = grad_x.to(x_flat.dtype)
         return grad_x.view(*x_shape), grad_h.view(*x_shape[:-1])
@@ -1652,9 +1646,9 @@ def forward(
         x_shape = x.shape
         x_flat, _ = _flatten_tokens(x)
         N, HC, C = x_flat.shape
-        f_flat = f_out.contiguous().view(-1, C)
-        h_post_flat = h_post.contiguous().view(-1, HC).to(torch.float32)
-        h_res_flat = h_res.contiguous().view(-1, HC, HC).to(torch.float32)
+        f_flat = f_out.view(-1, C)
+        h_post_flat = h_post.view(-1, HC).to(torch.float32)
+        h_res_flat = h_res.view(-1, HC, HC).to(torch.float32)
         out = mhc_post_res_fwd(x_flat, f_flat, h_post_flat, h_res_flat)  # [N,HC,C] fp32
         ctx.save_for_backward(x_flat, f_flat, h_post_flat, h_res_flat)
         ctx.x_shape = x_shape
@@ -1667,7 +1661,7 @@ def backward(ctx: Any, grad_out: torch.Tensor):
         x_flat, f_flat, h_post_flat, h_res_flat = ctx.saved_tensors
         x_shape = ctx.x_shape
         N, HC, C = x_flat.shape
-        go = grad_out.contiguous().view(-1, HC, C).to(torch.float32)
+        go = grad_out.view(-1, HC, C).to(torch.float32)
 
         grad_x, grad_f, grad_hpost, grad_hres = mhc_post_res_bwd(x_flat, f_flat, h_post_flat, h_res_flat, go)
 
src/liger_kernel/transformers/mhc.py

Lines changed: 3 additions & 9 deletions
@@ -113,15 +113,9 @@ def __init__(
         m = hc * hc + 2 * hc
         k = hc * c
 
-        layer_device = None
-        for param in self.layer.parameters(recurse=True):
-            layer_device = param.device
-            break
-        if layer_device is None:
-            for buf in self.layer.buffers(recurse=True):
-                layer_device = buf.device
-                break
-        if layer_device is None:
+        try:
+            layer_device = next(self.layer.parameters()).device
+        except StopIteration:
             layer_device = torch.device("cpu")
 
         # Note: for best speed, keep phi in BF16/FP16 to enable tensor-core matmul in Triton.
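On the device-detection simplification just above: `next(module.parameters())` raises `StopIteration` when the wrapped layer has no parameters, which is exactly what the new `except` branch catches; note that, unlike the removed loop, the fallback no longer consults buffers. A self-contained illustration of the same pattern (the helper name is hypothetical):

```python
import torch
import torch.nn as nn


def infer_device(layer: nn.Module) -> torch.device:
    # First parameter wins; parameter-less modules fall back to CPU,
    # matching the try/except added in LigerMHC.__init__.
    try:
        return next(layer.parameters()).device
    except StopIteration:
        return torch.device("cpu")


print(infer_device(nn.Linear(4, 4)))  # device of the Linear's weight
print(infer_device(nn.ReLU()))        # cpu fallback: ReLU has no parameters
```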
