diff --git a/benchmark/scripts/benchmark_mhc.py b/benchmark/scripts/benchmark_mhc.py new file mode 100644 index 000000000..47cdd6336 --- /dev/null +++ b/benchmark/scripts/benchmark_mhc.py @@ -0,0 +1,255 @@ +import os +import sys + +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.functional import liger_mhc_coeffs +from liger_kernel.transformers.functional import liger_mhc_post_res +from liger_kernel.transformers.functional import liger_mhc_pre +from liger_kernel.utils import infer_device + +device = infer_device() + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + from test.transformers.test_mhc import mhc_coeffs_ref + + T = input.x + B = input.extra_benchmark_config["B"] + HC = input.extra_benchmark_config["HC"] + C = input.extra_benchmark_config["C"] + sub_kernel = input.extra_benchmark_config["sub_kernel"] + tmax = input.extra_benchmark_config["tmax"] + rms_eps = input.extra_benchmark_config["rms_eps"] + pre_eps = input.extra_benchmark_config["pre_eps"] + sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"] + post_mult = input.extra_benchmark_config["post_mult"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult) + need_grad = mode in ("backward", "full") + + x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad) + K, M = HC * C, HC * HC + 2 * HC + phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(need_grad) + b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=need_grad) + alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad) + alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad) + alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad) + + grad_to_none = [x, phi, b_param, alpha_pre, alpha_post, alpha_res] if need_grad else None + + if sub_kernel == "coeffs": + + def fwd(): + if provider == "liger": + return liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg) + return mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg) + + def fwd_loss(): + h_pre, h_post, h_res = fwd() + return h_pre.square().mean() + h_post.square().mean() + h_res.square().mean() + + elif sub_kernel == "pre": + with torch.no_grad(): + h_pre_c, _, _ = liger_mhc_coeffs( + x.detach(), + phi.detach(), + b_param.detach(), + alpha_pre.detach(), + alpha_post.detach(), + alpha_res.detach(), + **coeffs_cfg, + ) + h_pre_c.requires_grad_(need_grad) + grad_to_none = [x, h_pre_c] if need_grad else None + + def fwd(): + if provider == "liger": + return liger_mhc_pre(x, h_pre_c) + return (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2) + + def fwd_loss(): + return fwd().square().mean() + + elif sub_kernel == "post_res": + with torch.no_grad(): + _, h_post_c, h_res_c = liger_mhc_coeffs( + x.detach(), + phi.detach(), + b_param.detach(), + alpha_pre.detach(), + alpha_post.detach(), + alpha_res.detach(), + **coeffs_cfg, + ) + h_post_c.requires_grad_(need_grad) + 
h_res_c.requires_grad_(need_grad) + f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad) + grad_to_none = [x, f_out, h_post_c, h_res_c] if need_grad else None + + def fwd(): + if provider == "liger": + return liger_mhc_post_res(x, f_out, h_post_c, h_res_c) + return torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze( + -1 + ) * f_out.float().unsqueeze(-2) + + def fwd_loss(): + return fwd().square().mean() + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) + elif mode == "backward": + y = fwd_loss() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=grad_to_none, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd_loss() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=grad_to_none, rep=100, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + from test.transformers.test_mhc import mhc_coeffs_ref + + T = input.x + B = input.extra_benchmark_config["B"] + HC = input.extra_benchmark_config["HC"] + C = input.extra_benchmark_config["C"] + sub_kernel = input.extra_benchmark_config["sub_kernel"] + tmax = input.extra_benchmark_config["tmax"] + rms_eps = input.extra_benchmark_config["rms_eps"] + pre_eps = input.extra_benchmark_config["pre_eps"] + sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"] + post_mult = input.extra_benchmark_config["post_mult"] + provider = input.kernel_provider + + coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult) + + x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=True) + K, M = HC * C, HC * HC + 2 * HC + phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(True) + b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=True) + alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + + if sub_kernel == "coeffs": + + def full(): + if provider == "liger": + hp, hpo, hr = liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg) + else: + hp, hpo, hr = mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg) + (hp.square().mean() + hpo.square().mean() + hr.square().mean()).backward() + + elif sub_kernel == "pre": + with torch.no_grad(): + h_pre_c, _, _ = liger_mhc_coeffs( + x.detach(), + phi.detach(), + b_param.detach(), + alpha_pre.detach(), + alpha_post.detach(), + alpha_res.detach(), + **coeffs_cfg, + ) + h_pre_c.requires_grad_(True) + + def full(): + if provider == "liger": + out = liger_mhc_pre(x, h_pre_c) + else: + out = (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2) + out.square().mean().backward() + + elif sub_kernel == "post_res": + with torch.no_grad(): + _, h_post_c, h_res_c = liger_mhc_coeffs( + x.detach(), + phi.detach(), + b_param.detach(), + alpha_pre.detach(), + alpha_post.detach(), + alpha_res.detach(), + **coeffs_cfg, + ) + h_post_c.requires_grad_(True) + h_res_c.requires_grad_(True) + f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=True) + + def full(): + if 
provider == "liger": + out = liger_mhc_post_res(x, f_out, h_post_c, h_res_c) + else: + out = torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze( + -1 + ) * f_out.float().unsqueeze(-2) + out.square().mean().backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + for sub_kernel in ["coeffs", "pre", "post_res"]: + common_configs = { + "kernel_name": f"mhc_{sub_kernel}", + "x_name": "T", + "x_label": "Sequence Length (T)", + "x_values": [2**i for i in range(7, 12)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "B": 4, + "HC": 4, + "C": 4096, + "tmax": 20, + "rms_eps": 1e-6, + "pre_eps": 0.0, + "sinkhorn_eps": 1e-6, + "post_mult": 2.0, + "sub_kernel": sub_kernel, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_mhc, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + + run_benchmarks( + bench_test_fn=bench_memory_mhc, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_mhc_lm.py b/benchmark/scripts/benchmark_mhc_lm.py new file mode 100644 index 000000000..6330a0e1a --- /dev/null +++ b/benchmark/scripts/benchmark_mhc_lm.py @@ -0,0 +1,455 @@ +import os +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.mhc import LigerMHC +from liger_kernel.utils import infer_device + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + +device = infer_device() + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size: int, *, eps: float, dtype: torch.dtype, device: str): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + var = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(var + self.eps) + return x * self.weight + + +def _build_rope_cache(seq_len: int, head_dim: int, *, device: torch.device, dtype: torch.dtype): + inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim)) + positions = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.einsum("i,j->ij", positions, inv_freq) + cos = freqs.cos().to(dtype) + sin = freqs.sin().to(dtype) + return cos, sin + + +def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1) + + +class MiniLlamaAttention(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, *, dtype: torch.dtype, device: str): + super().__init__() + assert hidden_size % num_heads == 0 + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + assert self.head_dim % 2 == 0, "head_dim must be even for RoPE" + + self.q_proj = nn.Linear(hidden_size, 
hidden_size, bias=False, dtype=dtype, device=device) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bsz, seq_len, _ = x.shape + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + cos, sin = _build_rope_cache(seq_len, self.head_dim, device=x.device, dtype=q.dtype) + q = _apply_rope(q, cos, sin) + k = _apply_rope(k, cos, sin) + + attn = F.scaled_dot_product_attention(q, k, v, is_causal=True) + attn = attn.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size) + return self.o_proj(attn) + + +class MiniLlamaMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_mult: int, *, dtype: torch.dtype, device: str): + super().__init__() + intermediate_size = hidden_size * intermediate_mult + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class AttentionBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, *, dtype: torch.dtype, device: str): + super().__init__() + self.norm = RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.attn = MiniLlamaAttention(hidden_size, num_heads, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.attn(self.norm(x)) + + +class MLPBlock(nn.Module): + def __init__(self, hidden_size: int, intermediate_mult: int, *, dtype: torch.dtype, device: str): + super().__init__() + self.norm = RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.mlp = MiniLlamaMLP(hidden_size, intermediate_mult, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(self.norm(x)) + + +class TorchMHC(nn.Module): + def __init__( + self, + layer: nn.Module, + *, + hc: int, + c: int, + tmax: int, + rms_eps: float, + pre_eps: float, + sinkhorn_eps: float, + post_mult: float, + phi_dtype: torch.dtype, + ): + super().__init__() + self.layer = layer + self.hc = int(hc) + self.c = int(c) + self.tmax = int(tmax) + self.rms_eps = float(rms_eps) + self.pre_eps = float(pre_eps) + self.sinkhorn_eps = float(sinkhorn_eps) + self.post_mult = float(post_mult) + + layer_param = next(layer.parameters()) + device = layer_param.device + + m = hc * hc + 2 * hc + k = hc * c + self.phi = nn.Parameter(torch.randn(k, m, dtype=phi_dtype, device=device) * 0.02) + self.b = nn.Parameter(torch.zeros(m, dtype=torch.float32, device=device)) + self.alpha_pre = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device)) + self.alpha_post = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device)) + self.alpha_res = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device)) + + self.layer_dtype = layer_param.dtype + + def _coeffs(self, x: torch.Tensor) -> 
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from test.transformers.test_mhc import mhc_coeffs_ref + + return mhc_coeffs_ref( + x, + self.phi, + self.b, + self.alpha_pre, + self.alpha_post, + self.alpha_res, + tmax=self.tmax, + rms_eps=self.rms_eps, + pre_eps=self.pre_eps, + sinkhorn_eps=self.sinkhorn_eps, + post_mult=self.post_mult, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_pre, h_post, h_res = self._coeffs(x) + x_in = (x.float() * h_pre.unsqueeze(-1)).sum(dim=-2) + if x_in.dtype != self.layer_dtype: + x_in = x_in.to(self.layer_dtype) + f_out = self.layer(x_in) + x_out = torch.einsum("...oi,...ic->...oc", h_res, x.float()) + h_post.unsqueeze(-1) * f_out.float().unsqueeze( + -2 + ) + return x_out.to(x.dtype) + + +class MHCDecoderLayer(nn.Module): + def __init__( + self, + mhc_cls: type[nn.Module], + *, + hidden_size: int, + hc: int, + num_heads: int, + intermediate_mult: int, + tmax: int, + dtype: torch.dtype, + device: str, + ): + super().__init__() + attn = AttentionBlock(hidden_size, num_heads, dtype=dtype, device=device) + mlp = MLPBlock(hidden_size, intermediate_mult, dtype=dtype, device=device) + self.attn = mhc_cls( + attn, + hc=hc, + c=hidden_size, + tmax=tmax, + rms_eps=1e-6, + pre_eps=1e-4, + sinkhorn_eps=1e-6, + post_mult=2.0, + phi_dtype=dtype, + ) + self.mlp = mhc_cls( + mlp, + hc=hc, + c=hidden_size, + tmax=tmax, + rms_eps=1e-6, + pre_eps=1e-4, + sinkhorn_eps=1e-6, + post_mult=2.0, + phi_dtype=dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.attn(x) + x = self.mlp(x) + return x + + +class BenchMiniMHCLM(nn.Module): + def __init__( + self, + mhc_cls: type[nn.Module], + *, + vocab_size: int, + hidden_size: int, + hc: int, + num_layers: int, + num_heads: int, + intermediate_mult: int, + tmax: int, + dtype: torch.dtype, + device: str, + ): + super().__init__() + self.hc = hc + self.hidden_size = hidden_size + self.embed = nn.Embedding(vocab_size, hc * hidden_size, dtype=dtype, device=device) + self.layers = nn.ModuleList( + [ + MHCDecoderLayer( + mhc_cls, + hidden_size=hidden_size, + hc=hc, + num_heads=num_heads, + intermediate_mult=intermediate_mult, + tmax=tmax, + dtype=dtype, + device=device, + ) + for _ in range(num_layers) + ] + ) + self.final_norm = RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False, dtype=dtype, device=device) + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + x = self.embed(input_ids) + bsz, seq_len, _ = x.shape + x = x.view(bsz, seq_len, self.hc, self.hidden_size) + for layer in self.layers: + x = layer(x) + x = x.mean(dim=-2) + x = self.final_norm(x) + return self.lm_head(x) + + +def _build_model( + provider: str, + *, + hidden_size: int, + hc: int, + num_layers: int, + num_heads: int, + intermediate_mult: int, + vocab_size: int, + tmax: int, + dtype: torch.dtype, +): + mhc_cls = LigerMHC if provider == "liger" else TorchMHC + return BenchMiniMHCLM( + mhc_cls, + vocab_size=vocab_size, + hidden_size=hidden_size, + hc=hc, + num_layers=num_layers, + num_heads=num_heads, + intermediate_mult=intermediate_mult, + tmax=tmax, + dtype=dtype, + device=device, + ) + + +def bench_speed_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + hidden_size = int(input.x) + provider = input.kernel_provider + mode = input.kernel_operation_mode + extra = input.extra_benchmark_config + bsz = extra["B"] + seq_len = extra["T"] + hc = extra["HC"] + num_layers = extra["layers"] + num_heads = extra["heads"] + vocab_size = 
extra["vocab"] + dtype = extra["dtype"] + tmax = extra["tmax"] + intermediate_mult = extra["intermediate_mult"] + + if hidden_size % num_heads != 0: + raise ValueError("hidden_size must be divisible by num_heads") + + model = _build_model( + provider, + hidden_size=hidden_size, + hc=hc, + num_layers=num_layers, + num_heads=num_heads, + intermediate_mult=intermediate_mult, + vocab_size=vocab_size, + tmax=tmax, + dtype=dtype, + ) + + input_ids = torch.randint(0, vocab_size, (bsz, seq_len), device=device) + + def fwd(): + return model(input_ids) + + def fwd_loss(): + return fwd().float().mean() + + grad_to_none = list(model.parameters()) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, grad_to_none=grad_to_none, rep=100) + elif mode == "backward": + loss = fwd_loss() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: loss.backward(retain_graph=True), + quantiles=QUANTILES, + grad_to_none=grad_to_none, + rep=100, + ) + elif mode == "full": + + def full(): + loss = fwd_loss() + loss.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=grad_to_none, rep=100) + else: + raise ValueError(f"Unknown mode: {mode}") + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + hidden_size = int(input.x) + provider = input.kernel_provider + extra = input.extra_benchmark_config + bsz = extra["B"] + seq_len = extra["T"] + hc = extra["HC"] + num_layers = extra["layers"] + num_heads = extra["heads"] + vocab_size = extra["vocab"] + dtype = extra["dtype"] + tmax = extra["tmax"] + intermediate_mult = extra["intermediate_mult"] + + if hidden_size % num_heads != 0: + raise ValueError("hidden_size must be divisible by num_heads") + + model = _build_model( + provider, + hidden_size=hidden_size, + hc=hc, + num_layers=num_layers, + num_heads=num_heads, + intermediate_mult=intermediate_mult, + vocab_size=vocab_size, + tmax=tmax, + dtype=dtype, + ) + + input_ids = torch.randint(0, vocab_size, (bsz, seq_len), device=device) + + def fwd(): + return model(input_ids) + + def full(): + loss = fwd().float().mean() + loss.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "mhc_llama_like_lm", + "x_name": "hidden_size", + "x_label": "hidden_size", + "x_values": [256, 512, 1024], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "B": 2, + "T": 256, + "HC": 4, + "layers": 2, + "heads": 8, + "vocab": 4096, + "dtype": torch.bfloat16, + "tmax": 8, + "intermediate_mult": 4, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_mhc_lm, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_mhc_lm, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/src/liger_kernel/ops/__init__.py b/src/liger_kernel/ops/__init__.py index 6a34b18b4..cc7d0b038 100644 --- a/src/liger_kernel/ops/__init__.py +++ b/src/liger_kernel/ops/__init__.py @@ -60,6 +60,9 @@ from liger_kernel.ops.layer_norm import layer_norm_backward # noqa: F401 from liger_kernel.ops.layer_norm import 
layer_norm_forward # noqa: F401 from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction # noqa: F401 +from liger_kernel.ops.mhc import LigerMHCCoeffsFunction # noqa: F401 +from liger_kernel.ops.mhc import LigerMHCPostResFunction # noqa: F401 +from liger_kernel.ops.mhc import LigerMHCPreFunction # noqa: F401 from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction # noqa: F401 from liger_kernel.ops.poly_norm import LigerPolyNormFunction # noqa: F401 from liger_kernel.ops.poly_norm import poly_norm_backward # noqa: F401 diff --git a/src/liger_kernel/ops/mhc.py b/src/liger_kernel/ops/mhc.py new file mode 100644 index 000000000..1a4569d33 --- /dev/null +++ b/src/liger_kernel/ops/mhc.py @@ -0,0 +1,1674 @@ +import math + +from typing import Any +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import ensure_contiguous + + +def _post_res_default_meta(c: int) -> Tuple[int, int, int, int]: + """ + Returns default (block_n, block_c, num_warps, num_stages) for post_res kernels. + Tuned for different hidden dimensions on NVIDIA GPUs. + """ + if c >= 4096: + return 32, 128, 8, 3 # (block_n, block_c, num_warps, num_stages) + if c >= 2048: + return 32, 128, 4, 2 + if c >= 1024: + return 32, 64, 4, 2 + return 32, 64, 2, 2 + + +def _post_res_meta( + c: int, + block_n: Optional[int], + block_c: Optional[int], + num_warps: Optional[int], + num_stages: Optional[int], +) -> Tuple[int, int, int, int]: + bn, bc, nw, ns = _post_res_default_meta(c) + return ( + bn if block_n is None else int(block_n), + bc if block_c is None else int(block_c), + nw if num_warps is None else int(num_warps), + ns if num_stages is None else int(num_stages), + ) + + +# ------------------------------------------------------------------------------------------------- +# (1) Coefficients: fused matmul + RMS scalar (Eq. 
14–15) +# mix = (x @ phi) * rsqrt(mean(x^2) + eps) +# +# We provide two paths: +# - TC path: x BF16/FP16 and phi BF16/FP16 (Tensor Cores) +# - TF32-ish path: x cast to FP32 and phi FP32 (relies on Triton/arch for TF32) +# ------------------------------------------------------------------------------------------------- + + +@triton.jit +def _mhc_mm_norm_fwd_kernel( + x_ptr, + phi_ptr, + mix_ptr, + invr_ptr, + N: tl.constexpr, + K: tl.constexpr, + M: tl.constexpr, + stride_xn: tl.constexpr, + stride_xk: tl.constexpr, + stride_phik: tl.constexpr, + stride_phim: tl.constexpr, + stride_mn: tl.constexpr, + stride_mm: tl.constexpr, + eps: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_M: tl.constexpr, + CAST_FP32: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_m = tl.program_id(1) + + n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + m_offs = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + acc = tl.zeros((BLOCK_N, BLOCK_M), tl.float32) + sumsq = tl.zeros((BLOCK_N,), tl.float32) + + for k0 in tl.static_range(0, K, BLOCK_K): + k_offs = k0 + tl.arange(0, BLOCK_K) + + x = tl.load( + x_ptr + n_offs[:, None] * stride_xn + k_offs[None, :] * stride_xk, + mask=(n_offs[:, None] < N) & (k_offs[None, :] < K), + other=0.0, + ) + if CAST_FP32: + x = x.to(tl.float32) + sumsq += tl.sum(x * x, axis=1) + else: + x_f = x.to(tl.float32) + sumsq += tl.sum(x_f * x_f, axis=1) + + phi = tl.load( + phi_ptr + k_offs[:, None] * stride_phik + m_offs[None, :] * stride_phim, + mask=(k_offs[:, None] < K) & (m_offs[None, :] < M), + other=0.0, + ) + if CAST_FP32: + phi = phi.to(tl.float32) + + acc += tl.dot(x, phi) + + invr = tl.rsqrt(sumsq / K + eps) + out = acc * invr[:, None] + + tl.store( + mix_ptr + n_offs[:, None] * stride_mn + m_offs[None, :] * stride_mm, + out, + mask=(n_offs[:, None] < N) & (m_offs[None, :] < M), + ) + if pid_m == 0: + tl.store(invr_ptr + n_offs, invr, mask=n_offs < N) + + +def mhc_mm_norm_fwd( + x: torch.Tensor, + phi: torch.Tensor, + eps: float, + *, + out_mix: Optional[torch.Tensor] = None, + out_invr: Optional[torch.Tensor] = None, + block_n: int = 32, + block_k: int = 256, + block_m: int = 32, + num_warps: int = 4, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused (x @ phi) + invr = rsqrt(mean(x^2)+eps) and returns mix=(x@phi)*invr. 
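+
+    Reference semantics (a plain PyTorch sketch of what the fused kernel computes,
+    not the implementation itself):
+
+        invr_ref = torch.rsqrt(x.float().pow(2).mean(dim=-1) + eps)   # [N]
+        mix_ref = (x.float() @ phi.float()) * invr_ref[:, None]       # [N, M]
+
+    The Triton kernel matches this up to rounding (Tensor Cores when x and phi are
+    both BF16/FP16, an FP32/TF32-style path otherwise).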
+ + Args: + x: [N, K] contiguous + phi: [K, M] contiguous + eps: float + Returns: + mix: [N, M] float32 + invr: [N] float32 + """ + assert x.is_contiguous(), "x must be contiguous" + assert phi.is_contiguous(), "phi must be contiguous" + + N, K = x.shape + K2, M = phi.shape + assert K2 == K, f"phi.shape[0] must match K: got {K2} vs {K}" + + if out_mix is None: + out_mix = torch.empty((N, M), device=x.device, dtype=torch.float32) + if out_invr is None: + out_invr = torch.empty((N,), device=x.device, dtype=torch.float32) + + grid = (triton.cdiv(N, block_n), triton.cdiv(M, block_m)) + + use_tc = (x.dtype == phi.dtype) and (x.dtype in (torch.float16, torch.bfloat16)) + + _mhc_mm_norm_fwd_kernel[grid]( + x, + phi, + out_mix, + out_invr, + N=N, + K=K, + M=M, + stride_xn=x.stride(0), + stride_xk=x.stride(1), + stride_phik=phi.stride(0), + stride_phim=phi.stride(1), + stride_mn=out_mix.stride(0), + stride_mm=out_mix.stride(1), + eps=eps, + BLOCK_N=block_n, + BLOCK_K=block_k, + BLOCK_M=block_m, + CAST_FP32=not use_tc, + num_warps=num_warps, + ) + return out_mix, out_invr + + +# ------------------------------------------------------------------------------------------------- +# Backward for fused (x @ phi) + RMS scalar +# +# mix = (x @ phi) * invr +# invr = rsqrt(mean(x^2) + eps) +# +# Given grad_mix, compute: +# grad_z = grad_mix * invr +# g = sum(grad_mix * (mix / invr)) = sum(grad_mix * mix) / invr +# factor = -(g / K) * invr^3 +# grad_x = grad_z @ phi^T + factor * x +# grad_phi = x^T @ grad_z +# +# grad_phi is accumulated into FP32 with atomic adds (split over N-chunks). +# ------------------------------------------------------------------------------------------------- + + +@triton.jit +def _mhc_mm_norm_bwd_fused_kernel( + x_ptr, + phi_ptr, + mix_ptr, + invr_ptr, + grad_mix_ptr, + grad_x_ptr, + grad_phi_ptr, + N: tl.constexpr, + K: tl.constexpr, + M: tl.constexpr, + stride_xn: tl.constexpr, + stride_xk: tl.constexpr, + stride_phik: tl.constexpr, + stride_phim: tl.constexpr, + stride_mn: tl.constexpr, + stride_mm: tl.constexpr, + stride_invr: tl.constexpr, + stride_gmn: tl.constexpr, + stride_gmm: tl.constexpr, + stride_gxn: tl.constexpr, + stride_gxk: tl.constexpr, + stride_gpk: tl.constexpr, + stride_gpm: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_M: tl.constexpr, + CAST_FP32: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_k = tl.program_id(1) + + n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + k_offs = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) + + invr = tl.load(invr_ptr + n_offs * stride_invr, mask=n_offs < N, other=0.0).to(tl.float32) + + x = tl.load( + x_ptr + n_offs[:, None] * stride_xn + k_offs[None, :] * stride_xk, + mask=(n_offs[:, None] < N) & (k_offs[None, :] < K), + other=0.0, + ) + if CAST_FP32: + x = x.to(tl.float32) + x_f = x + else: + x_f = x.to(tl.float32) + + acc = tl.zeros((BLOCK_N, BLOCK_K), tl.float32) + g_acc = tl.zeros((BLOCK_N,), tl.float32) + + for m0 in tl.static_range(0, M, BLOCK_M): + m_offs = m0 + tl.arange(0, BLOCK_M) + + grad_mix = tl.load( + grad_mix_ptr + n_offs[:, None] * stride_gmn + m_offs[None, :] * stride_gmm, + mask=(n_offs[:, None] < N) & (m_offs[None, :] < M), + other=0.0, + ).to(tl.float32) + + mix = tl.load( + mix_ptr + n_offs[:, None] * stride_mn + m_offs[None, :] * stride_mm, + mask=(n_offs[:, None] < N) & (m_offs[None, :] < M), + other=0.0, + ).to(tl.float32) + + g_acc += tl.sum(grad_mix * mix, axis=1) + + phi = tl.load( + phi_ptr + k_offs[:, None] * stride_phik + m_offs[None, :] * stride_phim, + mask=(k_offs[:, 
None] < K) & (m_offs[None, :] < M), + other=0.0, + ) + if CAST_FP32: + phi = phi.to(tl.float32) + grad_z = grad_mix * invr[:, None] + else: + grad_z = (grad_mix * invr[:, None]).to(phi.dtype) + + acc += tl.dot(grad_z, tl.trans(phi)) + + dphi = tl.dot(tl.trans(x), grad_z) + tl.atomic_add( + grad_phi_ptr + k_offs[:, None] * stride_gpk + m_offs[None, :] * stride_gpm, + dphi, + mask=(k_offs[:, None] < K) & (m_offs[None, :] < M), + ) + + g = g_acc / invr + invr3 = invr * invr * invr + factor = (-g * invr3) / K + + gx = acc + x_f * factor[:, None] + + if CAST_FP32: + tl.store( + grad_x_ptr + n_offs[:, None] * stride_gxn + k_offs[None, :] * stride_gxk, + gx, + mask=(n_offs[:, None] < N) & (k_offs[None, :] < K), + ) + else: + tl.store( + grad_x_ptr + n_offs[:, None] * stride_gxn + k_offs[None, :] * stride_gxk, + gx.to(x.dtype), + mask=(n_offs[:, None] < N) & (k_offs[None, :] < K), + ) + + +def mhc_mm_norm_bwd( + x: torch.Tensor, + phi: torch.Tensor, + mix: torch.Tensor, + invr: torch.Tensor, + grad_mix: torch.Tensor, + *, + out_grad_x: Optional[torch.Tensor] = None, + out_grad_phi: Optional[torch.Tensor] = None, + block_n: int = 32, + block_k: int = 256, + block_m: int = 32, + num_warps: int = 4, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Triton backward for `mhc_mm_norm_fwd`. + + Returns: + grad_x: [N, K] same dtype as x + grad_phi: [K, M] FP32 (safe for atomic adds; cast on return if needed) + + Note: + grad_phi is accumulated via atomic_add in FP32. For very large N + (batch * sequence length > 1M), accumulated rounding errors may + become noticeable. This is typically not an issue for standard + training configurations. + """ + assert ( + x.is_contiguous() + and phi.is_contiguous() + and mix.is_contiguous() + and invr.is_contiguous() + and grad_mix.is_contiguous() + ) + + N, K = x.shape + K2, M = phi.shape + assert K2 == K + assert mix.shape == (N, M) + assert grad_mix.shape == (N, M) + assert invr.shape == (N,) + + if out_grad_x is None: + out_grad_x = torch.empty_like(x) + if out_grad_phi is None: + out_grad_phi = torch.zeros((K, M), device=x.device, dtype=torch.float32) + + use_tc = (x.dtype == phi.dtype) and (x.dtype in (torch.float16, torch.bfloat16)) + + grid = (triton.cdiv(N, block_n), triton.cdiv(K, block_k)) + _mhc_mm_norm_bwd_fused_kernel[grid]( + x, + phi, + mix, + invr, + grad_mix, + out_grad_x, + out_grad_phi, + N=N, + K=K, + M=M, + stride_xn=x.stride(0), + stride_xk=x.stride(1), + stride_phik=phi.stride(0), + stride_phim=phi.stride(1), + stride_mn=mix.stride(0), + stride_mm=mix.stride(1), + stride_invr=invr.stride(0), + stride_gmn=grad_mix.stride(0), + stride_gmm=grad_mix.stride(1), + stride_gxn=out_grad_x.stride(0), + stride_gxk=out_grad_x.stride(1), + stride_gpk=out_grad_phi.stride(0), + stride_gpm=out_grad_phi.stride(1), + BLOCK_N=block_n, + BLOCK_K=block_k, + BLOCK_M=block_m, + CAST_FP32=not use_tc, + num_warps=num_warps, + ) + + if out_grad_phi.dtype != phi.dtype: + out_grad_phi = out_grad_phi.to(phi.dtype) + return out_grad_x, out_grad_phi + + +# ------------------------------------------------------------------------------------------------- +# Sinkhorn-Knopp forward/backward for H_res (Eq. 
19) +# ------------------------------------------------------------------------------------------------- + + +@triton.jit +def _mhc_split_sinkhorn_fwd_kernel( + mix_ptr, + b_ptr, + hpre_ptr, + hpost_ptr, + hres_ptr, + hist_ptr, + N: tl.constexpr, + HC: tl.constexpr, + M: tl.constexpr, + stride_mn: tl.constexpr, + stride_mm: tl.constexpr, + stride_hp_n: tl.constexpr, + stride_hp_h: tl.constexpr, + stride_hq_n: tl.constexpr, + stride_hq_h: tl.constexpr, + stride_hr_n: tl.constexpr, + stride_hr_i: tl.constexpr, + stride_hr_j: tl.constexpr, + stride_hn: tl.constexpr, + stride_ht: tl.constexpr, + stride_hi: tl.constexpr, + stride_hj: tl.constexpr, + alpha_pre_ptr, + alpha_post_ptr, + alpha_res_ptr, + pre_eps: tl.constexpr, + sinkhorn_eps: tl.constexpr, + post_mult: tl.constexpr, + TMAX: tl.constexpr, + STORE_HIST: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= N: + return + + # Load scalar alpha parameters from GPU memory (avoids CPU sync) + alpha_pre = tl.load(alpha_pre_ptr).to(tl.float32) + alpha_post = tl.load(alpha_post_ptr).to(tl.float32) + alpha_res = tl.load(alpha_res_ptr).to(tl.float32) + + # --- Pre/post logits + j = tl.arange(0, HC) + mix_pre = tl.load(mix_ptr + pid * stride_mn + j * stride_mm).to(tl.float32) + mix_post = tl.load(mix_ptr + pid * stride_mn + (HC + j) * stride_mm).to(tl.float32) + + b_pre = tl.load(b_ptr + j).to(tl.float32) + b_post = tl.load(b_ptr + (HC + j)).to(tl.float32) + + pre_logits = mix_pre * alpha_pre + b_pre + post_logits = mix_post * alpha_post + b_post + + pre = tl.sigmoid(pre_logits) + pre_eps + post = tl.sigmoid(post_logits) * post_mult + + tl.store(hpre_ptr + pid * stride_hp_n + j * stride_hp_h, pre) + tl.store(hpost_ptr + pid * stride_hq_n + j * stride_hq_h, post) + + # --- Residual logits matrix [HC, HC] + rows = tl.arange(0, HC)[:, None] + cols = tl.arange(0, HC)[None, :] + flat = rows * HC + cols # [HC,HC] + + mix_res = tl.load(mix_ptr + pid * stride_mn + (2 * HC + flat) * stride_mm).to(tl.float32) + b_res = tl.load(b_ptr + (2 * HC + flat)).to(tl.float32) + + logits = mix_res * alpha_res + b_res + + # Sinkhorn: initial row-softmax (stable) then alternating row/col norms. 
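+    # Each subsequent pass applies, with eps = sinkhorn_eps:
+    #   mat <- mat / (row_sum(mat) + eps)   (row normalization)
+    #   mat <- mat / (col_sum(mat) + eps)   (column normalization)
+    # which drives mat toward a doubly stochastic matrix over TMAX passes.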
+ row_max = tl.max(logits, axis=1) + e = tl.exp(logits - row_max[:, None]) + row_sum = tl.sum(e, axis=1) + mat = e / row_sum[:, None] + sinkhorn_eps + + col_sum = tl.sum(mat, axis=0) + mat = mat / (col_sum[None, :] + sinkhorn_eps) + + if STORE_HIST: + tl.store( + hist_ptr + pid * stride_hn + 0 * stride_ht + rows * stride_hi + cols * stride_hj, + mat, + ) + + for t in tl.static_range(0, TMAX - 1): + row_sum = tl.sum(mat, axis=1) + mat = mat / (row_sum[:, None] + sinkhorn_eps) + col_sum = tl.sum(mat, axis=0) + mat = mat / (col_sum[None, :] + sinkhorn_eps) + if STORE_HIST: + tl.store( + hist_ptr + pid * stride_hn + (t + 1) * stride_ht + rows * stride_hi + cols * stride_hj, + mat, + ) + + # Store h_res [N, HC, HC] (row-major: out, in) + tl.store(hres_ptr + pid * stride_hr_n + rows * stride_hr_i + cols * stride_hr_j, mat) + + +@triton.jit +def _mhc_sinkhorn_bwd_kernel( + mix_ptr, + b_ptr, + grad_out_ptr, + grad_logits_ptr, + N: tl.constexpr, + HC: tl.constexpr, + stride_mn: tl.constexpr, + stride_mm: tl.constexpr, + stride_go_n: tl.constexpr, + stride_go_i: tl.constexpr, + stride_go_j: tl.constexpr, + stride_gl_n: tl.constexpr, + stride_gl_i: tl.constexpr, + stride_gl_j: tl.constexpr, + alpha_res_ptr, + sinkhorn_eps: tl.constexpr, + TMAX: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= N: + return + + alpha_res = tl.load(alpha_res_ptr).to(tl.float32) + + rows = tl.arange(0, HC)[:, None] + cols = tl.arange(0, HC)[None, :] + flat = rows * HC + cols + + # Rebuild logits + mix_res = tl.load(mix_ptr + pid * stride_mn + (2 * HC + flat) * stride_mm).to(tl.float32) + b_res = tl.load(b_ptr + (2 * HC + flat)).to(tl.float32) + logits = mix_res * alpha_res + b_res + + # Forward recompute (no lists) and backward with recompute per step. + row_max = tl.max(logits, axis=1) + e = tl.exp(logits - row_max[:, None]) + row_sum0 = tl.sum(e, axis=1) + p = e / row_sum0[:, None] # softmax, row-wise + p_eps = p + sinkhorn_eps + + col_sum0 = tl.sum(p_eps, axis=0) + mat0 = p_eps / (col_sum0[None, :] + sinkhorn_eps) + + # Start backward from grad_out + g = tl.load( + grad_out_ptr + pid * stride_go_n + rows * stride_go_i + cols * stride_go_j, + ).to(tl.float32) + + # Reverse iterations (TMAX-1 .. 
1), recomputing mat_t, rs_t, cs_t + for t in tl.static_range(TMAX - 1, 0, -1): + mat = mat0 + rs_t = row_sum0 + cs_t = col_sum0 + mat_t = mat0 + + for s in tl.static_range(1, TMAX): + rs = tl.sum(mat, axis=1) + mat = mat / (rs[:, None] + sinkhorn_eps) + cs = tl.sum(mat, axis=0) + mat = mat / (cs[None, :] + sinkhorn_eps) + if s == t: + mat_t = mat + rs_t = rs + cs_t = cs + + denom_col = cs_t + sinkhorn_eps # [HC] + dot_col = tl.sum(g * mat_t, axis=0) # [HC] + g_row = (g - dot_col[None, :]) / denom_col[None, :] + + m_row = mat_t * denom_col[None, :] # invert col norm: m_row = m_out * denom + denom_row = rs_t + sinkhorn_eps + dot_row = tl.sum(g_row * m_row, axis=1) + g = (g_row - dot_row[:, None]) / denom_row[:, None] + + # Undo initial col norm (t=0) + denom_col0 = col_sum0 + sinkhorn_eps + dot_col0 = tl.sum(g * mat0, axis=0) + g_p = (g - dot_col0[None, :]) / denom_col0[None, :] + + # Softmax backward on rows: p * (g_p - sum(g_p * p)) + dot_soft = tl.sum(g_p * p, axis=1) + grad_logits = p * (g_p - dot_soft[:, None]) + + tl.store(grad_logits_ptr + pid * stride_gl_n + rows * stride_gl_i + cols * stride_gl_j, grad_logits) + + +@triton.jit +def _mhc_sinkhorn_bwd_hist_kernel( + mix_ptr, + b_ptr, + hist_ptr, + grad_out_ptr, + grad_logits_ptr, + N: tl.constexpr, + HC: tl.constexpr, + stride_mn: tl.constexpr, + stride_mm: tl.constexpr, + stride_hn: tl.constexpr, + stride_ht: tl.constexpr, + stride_hi: tl.constexpr, + stride_hj: tl.constexpr, + stride_go_n: tl.constexpr, + stride_go_i: tl.constexpr, + stride_go_j: tl.constexpr, + stride_gl_n: tl.constexpr, + stride_gl_i: tl.constexpr, + stride_gl_j: tl.constexpr, + alpha_res_ptr, + sinkhorn_eps: tl.constexpr, + TMAX: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= N: + return + + alpha_res = tl.load(alpha_res_ptr).to(tl.float32) + + rows = tl.arange(0, HC)[:, None] + cols = tl.arange(0, HC)[None, :] + flat = rows * HC + cols + + # Rebuild logits + mix_res = tl.load(mix_ptr + pid * stride_mn + (2 * HC + flat) * stride_mm).to(tl.float32) + b_res = tl.load(b_ptr + (2 * HC + flat)).to(tl.float32) + logits = mix_res * alpha_res + b_res + + # Initial row-softmax + row_max = tl.max(logits, axis=1) + e = tl.exp(logits - row_max[:, None]) + row_sum0 = tl.sum(e, axis=1) + p = e / row_sum0[:, None] + p_eps = p + sinkhorn_eps + + col_sum0 = tl.sum(p_eps, axis=0) + mat0 = p_eps / (col_sum0[None, :] + sinkhorn_eps) + + # Start backward from grad_out + g = tl.load( + grad_out_ptr + pid * stride_go_n + rows * stride_go_i + cols * stride_go_j, + ).to(tl.float32) + + # Reverse iterations (TMAX-1 .. 
1) using stored mats + for t in tl.static_range(TMAX - 1, 0, -1): + mat_t = tl.load(hist_ptr + pid * stride_hn + t * stride_ht + rows * stride_hi + cols * stride_hj).to(tl.float32) + mat_prev = tl.load(hist_ptr + pid * stride_hn + (t - 1) * stride_ht + rows * stride_hi + cols * stride_hj).to( + tl.float32 + ) + + row_sum = tl.sum(mat_prev, axis=1) + mat_row = mat_prev / (row_sum[:, None] + sinkhorn_eps) + col_sum = tl.sum(mat_row, axis=0) + denom_col = col_sum + sinkhorn_eps + + dot_col = tl.sum(g * mat_t, axis=0) + g_row = (g - dot_col[None, :]) / denom_col[None, :] + + m_row = mat_t * denom_col[None, :] + denom_row = row_sum + sinkhorn_eps + dot_row = tl.sum(g_row * m_row, axis=1) + g = (g_row - dot_row[:, None]) / denom_row[:, None] + + # Undo initial col norm (t=0) + denom_col0 = col_sum0 + sinkhorn_eps + dot_col0 = tl.sum(g * mat0, axis=0) + g_p = (g - dot_col0[None, :]) / denom_col0[None, :] + + # Softmax backward on rows: p * (g_p - sum(g_p * p)) + dot_soft = tl.sum(g_p * p, axis=1) + grad_logits = p * (g_p - dot_soft[:, None]) + + tl.store(grad_logits_ptr + pid * stride_gl_n + rows * stride_gl_i + cols * stride_gl_j, grad_logits) + + +def mhc_split_sinkhorn_fwd( + mix: torch.Tensor, + b: torch.Tensor, + alpha_pre: torch.Tensor, + alpha_post: torch.Tensor, + alpha_res: torch.Tensor, + *, + tmax: int, + pre_eps: float, + sinkhorn_eps: float, + post_mult: float, + out_hpre: Optional[torch.Tensor] = None, + out_hpost: Optional[torch.Tensor] = None, + out_hres: Optional[torch.Tensor] = None, + out_hist: Optional[torch.Tensor] = None, + return_hist: bool = False, + num_warps: int = 1, +) -> Union[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] +]: + """ + Compute h_pre, h_post, h_res from `mix` (already normalized by RMS scalar). 
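+
+    Column layout of `mix` is [pre | post | res] with sizes HC, HC and HC*HC.
+    Per token the kernel computes (informal sketch; see the kernel for the exact
+    Sinkhorn ordering):
+
+        h_pre  = sigmoid(mix_pre  * alpha_pre  + b_pre)  + pre_eps        # [N, HC]
+        h_post = sigmoid(mix_post * alpha_post + b_post) * post_mult      # [N, HC]
+        h_res  = sinkhorn(mix_res * alpha_res  + b_res)                   # [N, HC, HC]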
+ + mix: [N, M] float32 where M = HC*HC + 2*HC + b: [M] float32 + """ + assert mix.is_contiguous() and b.is_contiguous() + + N, M = mix.shape + assert M == b.numel() + # infer HC from M = HC*HC + 2*HC + # Solve HC^2 + 2HC - M = 0 + HC = int((math.isqrt(4 + 4 * M) - 2) // 2) + assert HC * HC + 2 * HC == M, f"Invalid M for mHC: M={M}" + + if out_hpre is None: + out_hpre = torch.empty((N, HC), device=mix.device, dtype=torch.float32) + if out_hpost is None: + out_hpost = torch.empty((N, HC), device=mix.device, dtype=torch.float32) + if out_hres is None: + out_hres = torch.empty((N, HC, HC), device=mix.device, dtype=torch.float32) + if return_hist: + if out_hist is None: + out_hist = torch.empty((N, tmax, HC, HC), device=mix.device, dtype=torch.float32) + else: + if out_hist is None: + out_hist = torch.empty((1,), device=mix.device, dtype=torch.float32) + + grid = (N,) + + _mhc_split_sinkhorn_fwd_kernel[grid]( + mix, + b, + out_hpre, + out_hpost, + out_hres, + out_hist, + N=N, + HC=HC, + M=M, + stride_mn=mix.stride(0), + stride_mm=mix.stride(1), + stride_hp_n=out_hpre.stride(0), + stride_hp_h=out_hpre.stride(1), + stride_hq_n=out_hpost.stride(0), + stride_hq_h=out_hpost.stride(1), + stride_hr_n=out_hres.stride(0), + stride_hr_i=out_hres.stride(1), + stride_hr_j=out_hres.stride(2), + stride_hn=out_hist.stride(0) if out_hist.ndim > 1 else 0, + stride_ht=out_hist.stride(1) if out_hist.ndim > 1 else 0, + stride_hi=out_hist.stride(2) if out_hist.ndim > 1 else 0, + stride_hj=out_hist.stride(3) if out_hist.ndim > 1 else 0, + alpha_pre_ptr=alpha_pre.contiguous(), + alpha_post_ptr=alpha_post.contiguous(), + alpha_res_ptr=alpha_res.contiguous(), + pre_eps=pre_eps, + sinkhorn_eps=sinkhorn_eps, + post_mult=post_mult, + TMAX=tmax, + STORE_HIST=return_hist, + num_warps=num_warps, + ) + if return_hist: + return out_hpre, out_hpost, out_hres, out_hist + return out_hpre, out_hpost, out_hres + + +def mhc_sinkhorn_bwd( + mix: torch.Tensor, + b: torch.Tensor, + alpha_res: torch.Tensor, + grad_hres: torch.Tensor, + *, + tmax: int, + sinkhorn_eps: float, + hist: Optional[torch.Tensor] = None, + out_grad_logits: Optional[torch.Tensor] = None, + num_warps: int = 1, +) -> torch.Tensor: + """ + Backward for Sinkhorn: returns grad_logits (same shape as h_res). 
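+
+    Each forward normalization y = m / (sum(m) + sinkhorn_eps) is reversed with
+
+        grad_m = (grad_y - sum(grad_y * y)) / (sum(m) + sinkhorn_eps)
+
+    applied column-wise then row-wise for every Sinkhorn pass (recomputing the
+    intermediate matrices, or replaying them from `hist` when it is provided),
+    followed by the row-softmax backward to produce grad_logits.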
+ + mix: [N, M] float32 + b: [M] float32 + grad_hres: [N, HC, HC] float32 + """ + assert mix.is_contiguous() and b.is_contiguous() and grad_hres.is_contiguous() + + N, M = mix.shape + HC = grad_hres.shape[1] + assert grad_hres.shape == (N, HC, HC) + assert M == HC * HC + 2 * HC + + if out_grad_logits is None: + out_grad_logits = torch.empty((N, HC, HC), device=mix.device, dtype=torch.float32) + + grid = (N,) + + alpha_res_c = alpha_res.contiguous() + + if hist is not None: + assert hist.is_contiguous() + assert hist.shape == (N, tmax, HC, HC) + _mhc_sinkhorn_bwd_hist_kernel[grid]( + mix, + b, + hist, + grad_hres, + out_grad_logits, + N=N, + HC=HC, + stride_mn=mix.stride(0), + stride_mm=mix.stride(1), + stride_hn=hist.stride(0), + stride_ht=hist.stride(1), + stride_hi=hist.stride(2), + stride_hj=hist.stride(3), + stride_go_n=grad_hres.stride(0), + stride_go_i=grad_hres.stride(1), + stride_go_j=grad_hres.stride(2), + stride_gl_n=out_grad_logits.stride(0), + stride_gl_i=out_grad_logits.stride(1), + stride_gl_j=out_grad_logits.stride(2), + alpha_res_ptr=alpha_res_c, + sinkhorn_eps=sinkhorn_eps, + TMAX=tmax, + num_warps=num_warps, + ) + else: + _mhc_sinkhorn_bwd_kernel[grid]( + mix, + b, + grad_hres, + out_grad_logits, + N=N, + HC=HC, + stride_mn=mix.stride(0), + stride_mm=mix.stride(1), + stride_go_n=grad_hres.stride(0), + stride_go_i=grad_hres.stride(1), + stride_go_j=grad_hres.stride(2), + stride_gl_n=out_grad_logits.stride(0), + stride_gl_i=out_grad_logits.stride(1), + stride_gl_j=out_grad_logits.stride(2), + alpha_res_ptr=alpha_res_c, + sinkhorn_eps=sinkhorn_eps, + TMAX=tmax, + num_warps=num_warps, + ) + return out_grad_logits + + +# ------------------------------------------------------------------------------------------------- +# Apply kernels: mhc_pre and mhc_post_res (forward + backward) +# ------------------------------------------------------------------------------------------------- + + +@triton.jit +def _mhc_pre_fwd_kernel( + x_ptr, + hpre_ptr, + out_ptr, + N: tl.constexpr, + HC: tl.constexpr, + C: tl.constexpr, + stride_xn: tl.constexpr, + stride_xh: tl.constexpr, + stride_xc: tl.constexpr, + stride_hn: tl.constexpr, + stride_hh: tl.constexpr, + stride_on: tl.constexpr, + stride_oc: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + + n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + acc = tl.zeros((BLOCK_N, BLOCK_C), tl.float32) + for s in tl.static_range(0, HC): + h_s = tl.load( + hpre_ptr + n_offs * stride_hn + s * stride_hh, + mask=(n_offs < N), + other=0.0, + ).to(tl.float32) + xs = tl.load( + x_ptr + n_offs[:, None] * stride_xn + s * stride_xh + c_offs[None, :] * stride_xc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) + acc += xs * h_s[:, None] + + tl.store( + out_ptr + n_offs[:, None] * stride_on + c_offs[None, :] * stride_oc, + acc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + ) + + +@triton.jit +def _mhc_pre_bwd_kernel( + x_ptr, + hpre_ptr, + grad_out_ptr, + grad_x_ptr, + grad_h_ptr, + N: tl.constexpr, + HC: tl.constexpr, + C: tl.constexpr, + stride_xn: tl.constexpr, + stride_xh: tl.constexpr, + stride_xc: tl.constexpr, + stride_hn: tl.constexpr, + stride_hh: tl.constexpr, + stride_gon: tl.constexpr, + stride_goc: tl.constexpr, + stride_gxn: tl.constexpr, + stride_gxh: tl.constexpr, + stride_gxc: tl.constexpr, + stride_ghn: tl.constexpr, + stride_ghh: tl.constexpr, + BLOCK_N: tl.constexpr, + 
BLOCK_C: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + + n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + go = tl.load( + grad_out_ptr + n_offs[:, None] * stride_gon + c_offs[None, :] * stride_goc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) + + # grad_x = grad_out * hpre + for s in tl.static_range(0, HC): + h_s = tl.load( + hpre_ptr + n_offs * stride_hn + s * stride_hh, + mask=(n_offs < N), + other=0.0, + ).to(tl.float32) + gx = go * h_s[:, None] + tl.store( + grad_x_ptr + n_offs[:, None] * stride_gxn + s * stride_gxh + c_offs[None, :] * stride_gxc, + gx, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + ) + + # grad_hpre: dot(go, x_s) over C -> atomic add + xs = tl.load( + x_ptr + n_offs[:, None] * stride_xn + s * stride_xh + c_offs[None, :] * stride_xc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) + part = tl.sum(go * xs, axis=1) + tl.atomic_add( + grad_h_ptr + n_offs * stride_ghn + s * stride_ghh, + part, + mask=n_offs < N, + ) + + +def mhc_pre_fwd( + x: torch.Tensor, + h_pre: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + block_n: int = 32, + block_c: int = 128, + num_warps: int = 4, +) -> torch.Tensor: + assert x.is_contiguous() and h_pre.is_contiguous() + N, HC, C = x.shape + assert h_pre.shape == (N, HC) + + if out is None: + out = torch.empty((N, C), device=x.device, dtype=torch.float32) + + grid = (triton.cdiv(N, block_n), triton.cdiv(C, block_c)) + _mhc_pre_fwd_kernel[grid]( + x, + h_pre, + out, + N=N, + HC=HC, + C=C, + stride_xn=x.stride(0), + stride_xh=x.stride(1), + stride_xc=x.stride(2), + stride_hn=h_pre.stride(0), + stride_hh=h_pre.stride(1), + stride_on=out.stride(0), + stride_oc=out.stride(1), + BLOCK_N=block_n, + BLOCK_C=block_c, + num_warps=num_warps, + ) + return out + + +def mhc_pre_bwd( + x: torch.Tensor, + h_pre: torch.Tensor, + grad_out: torch.Tensor, + *, + out_grad_x: Optional[torch.Tensor] = None, + out_grad_h: Optional[torch.Tensor] = None, + block_n: int = 32, + block_c: int = 128, + num_warps: int = 4, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.is_contiguous() and h_pre.is_contiguous() and grad_out.is_contiguous() + N, HC, C = x.shape + assert grad_out.shape == (N, C) + + if out_grad_x is None: + out_grad_x = torch.empty_like(x, dtype=torch.float32) + if out_grad_h is None: + out_grad_h = torch.zeros((N, HC), device=x.device, dtype=torch.float32) + + grid = (triton.cdiv(N, block_n), triton.cdiv(C, block_c)) + _mhc_pre_bwd_kernel[grid]( + x, + h_pre, + grad_out, + out_grad_x, + out_grad_h, + N=N, + HC=HC, + C=C, + stride_xn=x.stride(0), + stride_xh=x.stride(1), + stride_xc=x.stride(2), + stride_hn=h_pre.stride(0), + stride_hh=h_pre.stride(1), + stride_gon=grad_out.stride(0), + stride_goc=grad_out.stride(1), + stride_gxn=out_grad_x.stride(0), + stride_gxh=out_grad_x.stride(1), + stride_gxc=out_grad_x.stride(2), + stride_ghn=out_grad_h.stride(0), + stride_ghh=out_grad_h.stride(1), + BLOCK_N=block_n, + BLOCK_C=block_c, + num_warps=num_warps, + ) + return out_grad_x, out_grad_h + + +@triton.jit +def _mhc_post_res_fwd_kernel( + x_ptr, + f_ptr, + hpost_ptr, + hres_ptr, + out_ptr, + N: tl.constexpr, + HC: tl.constexpr, + C: tl.constexpr, + stride_xn: tl.constexpr, + stride_xh: tl.constexpr, + stride_xc: tl.constexpr, + stride_fn: tl.constexpr, + stride_fc: tl.constexpr, + stride_hpn: tl.constexpr, + stride_hph: tl.constexpr, + stride_hrn: tl.constexpr, + stride_hri: 
tl.constexpr, + stride_hrj: tl.constexpr, + stride_on: tl.constexpr, + stride_oh: tl.constexpr, + stride_oc: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + + n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + f = tl.load( + f_ptr + n_offs[:, None] * stride_fn + c_offs[None, :] * stride_fc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) + + o2 = tl.arange(0, HC)[:, None] # [HC,1] + hpost = tl.load( + hpost_ptr + n_offs[None, :] * stride_hpn + o2 * stride_hph, + mask=(n_offs[None, :] < N), + other=0.0, + ).to(tl.float32) # [HC, BN] + + acc = f[None, :, :] * hpost[:, :, None] # [HC, BN, BC] + + # residual mixing: sum_i hres[o,i] * x_i + for i in tl.static_range(0, HC): + xs = tl.load( + x_ptr + n_offs[:, None] * stride_xn + i * stride_xh + c_offs[None, :] * stride_xc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) # [BN, BC] + w = tl.load( + hres_ptr + n_offs[None, :] * stride_hrn + o2 * stride_hri + i * stride_hrj, + mask=(n_offs[None, :] < N), + other=0.0, + ).to(tl.float32) # [HC, BN] + acc += xs[None, :, :] * w[:, :, None] + + o3 = tl.arange(0, HC)[:, None, None] + n3 = n_offs[None, :, None] + c3 = c_offs[None, None, :] + tl.store( + out_ptr + n3 * stride_on + o3 * stride_oh + c3 * stride_oc, + acc, + mask=(n3 < N) & (c3 < C), + ) + + +@triton.jit +def _mhc_post_res_bwd_kernel( + x_ptr, + f_ptr, + hpost_ptr, + hres_ptr, + grad_out_ptr, + grad_x_ptr, + grad_f_ptr, + grad_hpost_ptr, + grad_hres_ptr, + N: tl.constexpr, + HC: tl.constexpr, + C: tl.constexpr, + stride_xn: tl.constexpr, + stride_xh: tl.constexpr, + stride_xc: tl.constexpr, + stride_fn: tl.constexpr, + stride_fc: tl.constexpr, + stride_hpn: tl.constexpr, + stride_hph: tl.constexpr, + stride_hrn: tl.constexpr, + stride_hri: tl.constexpr, + stride_hrj: tl.constexpr, + stride_gon: tl.constexpr, + stride_goh: tl.constexpr, + stride_goc: tl.constexpr, + stride_gxn: tl.constexpr, + stride_gxh: tl.constexpr, + stride_gxc: tl.constexpr, + stride_gfn: tl.constexpr, + stride_gfc: tl.constexpr, + stride_ghpn: tl.constexpr, + stride_ghph: tl.constexpr, + stride_ghrn: tl.constexpr, + stride_ghri: tl.constexpr, + stride_ghrj: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + + n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + f = tl.load( + f_ptr + n_offs[:, None] * stride_fn + c_offs[None, :] * stride_fc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) + + o2 = tl.arange(0, HC)[:, None] # [HC,1] + hpost = tl.load( + hpost_ptr + n_offs[None, :] * stride_hpn + o2 * stride_hph, + mask=(n_offs[None, :] < N), + other=0.0, + ).to(tl.float32) # [HC, BN] + + o3 = tl.arange(0, HC)[:, None, None] + n3 = n_offs[None, :, None] + c3 = c_offs[None, None, :] + go = tl.load( + grad_out_ptr + n3 * stride_gon + o3 * stride_goh + c3 * stride_goc, + mask=(n3 < N) & (c3 < C), + other=0.0, + ).to(tl.float32) # [HC, BN, BC] + + # grad_f: sum_o go[o] * hpost[o] + gf = tl.sum(go * hpost[:, :, None], axis=0) + tl.store( + grad_f_ptr + n_offs[:, None] * stride_gfn + c_offs[None, :] * stride_gfc, + gf, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + ) + + # grad_hpost: dot(go[o], f) over C (atomic over C blocks) + part_hpost = tl.sum(go * f[None, :, :], axis=2) # [HC, BN] + tl.atomic_add( + 
grad_hpost_ptr + n_offs[None, :] * stride_ghpn + o2 * stride_ghph, + part_hpost, + mask=(n_offs[None, :] < N), + ) + + # grad_x: hres^T @ go (in-stream i gets sum_o hres[o,i] * go[o]) + for i in tl.static_range(0, HC): + w = tl.load( + hres_ptr + n_offs[None, :] * stride_hrn + o2 * stride_hri + i * stride_hrj, + mask=(n_offs[None, :] < N), + other=0.0, + ).to(tl.float32) # [HC, BN] + gx = tl.sum(go * w[:, :, None], axis=0) # [BN, BC] + tl.store( + grad_x_ptr + n_offs[:, None] * stride_gxn + i * stride_gxh + c_offs[None, :] * stride_gxc, + gx, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + ) + + # grad_hres[o,i]: dot(go[o], x[i]) over C (atomic) + for i in tl.static_range(0, HC): + xi = tl.load( + x_ptr + n_offs[:, None] * stride_xn + i * stride_xh + c_offs[None, :] * stride_xc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) + part_hres = tl.sum(go * xi[None, :, :], axis=2) # [HC, BN] + tl.atomic_add( + grad_hres_ptr + n_offs[None, :] * stride_ghrn + o2 * stride_ghri + i * stride_ghrj, + part_hres, + mask=(n_offs[None, :] < N), + ) + + +def mhc_post_res_fwd( + x: torch.Tensor, + f_out: torch.Tensor, + h_post: torch.Tensor, + h_res: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + block_n: Optional[int] = None, + block_c: Optional[int] = None, + num_warps: Optional[int] = None, + num_stages: Optional[int] = None, +) -> torch.Tensor: + assert x.is_contiguous() and f_out.is_contiguous() and h_post.is_contiguous() and h_res.is_contiguous() + + N, HC, C = x.shape + assert f_out.shape == (N, C) + assert h_post.shape == (N, HC) + assert h_res.shape == (N, HC, HC) + + if out is None: + out = torch.empty((N, HC, C), device=x.device, dtype=torch.float32) + + block_n, block_c, num_warps, num_stages = _post_res_meta(C, block_n, block_c, num_warps, num_stages) + + grid = (triton.cdiv(N, block_n), triton.cdiv(C, block_c)) + _mhc_post_res_fwd_kernel[grid]( + x, + f_out, + h_post, + h_res, + out, + N=N, + HC=HC, + C=C, + stride_xn=x.stride(0), + stride_xh=x.stride(1), + stride_xc=x.stride(2), + stride_fn=f_out.stride(0), + stride_fc=f_out.stride(1), + stride_hpn=h_post.stride(0), + stride_hph=h_post.stride(1), + stride_hrn=h_res.stride(0), + stride_hri=h_res.stride(1), + stride_hrj=h_res.stride(2), + stride_on=out.stride(0), + stride_oh=out.stride(1), + stride_oc=out.stride(2), + BLOCK_N=block_n, + BLOCK_C=block_c, + num_warps=num_warps, + num_stages=num_stages, + ) + return out + + +def mhc_post_res_bwd( + x: torch.Tensor, + f_out: torch.Tensor, + h_post: torch.Tensor, + h_res: torch.Tensor, + grad_out: torch.Tensor, + *, + out_grad_x: Optional[torch.Tensor] = None, + out_grad_f: Optional[torch.Tensor] = None, + out_grad_hpost: Optional[torch.Tensor] = None, + out_grad_hres: Optional[torch.Tensor] = None, + block_n: Optional[int] = None, + block_c: Optional[int] = None, + num_warps: Optional[int] = None, + num_stages: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + assert ( + x.is_contiguous() + and f_out.is_contiguous() + and h_post.is_contiguous() + and h_res.is_contiguous() + and grad_out.is_contiguous() + ) + + N, HC, C = x.shape + assert grad_out.shape == (N, HC, C) + + if out_grad_x is None: + out_grad_x = torch.empty_like(x, dtype=torch.float32) + if out_grad_f is None: + out_grad_f = torch.empty_like(f_out, dtype=torch.float32) + if out_grad_hpost is None: + out_grad_hpost = torch.zeros((N, HC), device=x.device, dtype=torch.float32) + if out_grad_hres is None: + out_grad_hres = torch.zeros((N, 
HC, HC), device=x.device, dtype=torch.float32) + + block_n, block_c, num_warps, num_stages = _post_res_meta(C, block_n, block_c, num_warps, num_stages) + + grid = (triton.cdiv(N, block_n), triton.cdiv(C, block_c)) + _mhc_post_res_bwd_kernel[grid]( + x, + f_out, + h_post, + h_res, + grad_out, + out_grad_x, + out_grad_f, + out_grad_hpost, + out_grad_hres, + N=N, + HC=HC, + C=C, + stride_xn=x.stride(0), + stride_xh=x.stride(1), + stride_xc=x.stride(2), + stride_fn=f_out.stride(0), + stride_fc=f_out.stride(1), + stride_hpn=h_post.stride(0), + stride_hph=h_post.stride(1), + stride_hrn=h_res.stride(0), + stride_hri=h_res.stride(1), + stride_hrj=h_res.stride(2), + stride_gon=grad_out.stride(0), + stride_goh=grad_out.stride(1), + stride_goc=grad_out.stride(2), + stride_gxn=out_grad_x.stride(0), + stride_gxh=out_grad_x.stride(1), + stride_gxc=out_grad_x.stride(2), + stride_gfn=out_grad_f.stride(0), + stride_gfc=out_grad_f.stride(1), + stride_ghpn=out_grad_hpost.stride(0), + stride_ghph=out_grad_hpost.stride(1), + stride_ghrn=out_grad_hres.stride(0), + stride_ghri=out_grad_hres.stride(1), + stride_ghrj=out_grad_hres.stride(2), + BLOCK_N=block_n, + BLOCK_C=block_c, + num_warps=num_warps, + num_stages=num_stages, + ) + return out_grad_x, out_grad_f, out_grad_hpost, out_grad_hres + + +def _flatten_tokens(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Size]: + """ + Flattens leading dimensions so x becomes [N, HC, C]. + Returns (x_flat, x_shape) where x_shape is the original shape. + """ + assert x.dim() >= 3, "x must be [..., HC, C]" + return x.contiguous().view(-1, x.shape[-2], x.shape[-1]), x.shape + + +class LigerMHCCoeffsFunction(torch.autograd.Function): + """ + Autograd function for mHC coefficient computation. + + Memory/Compute Trade-off: + When gradients are needed, Sinkhorn iteration history (hist) is saved + during forward to avoid recomputation in backward. This increases + memory usage by O(N * tmax * HC^2) but reduces backward compute. 
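+
+    As a rough, illustrative estimate (assuming the history is kept in fp32):
+    with N = 8192 tokens, tmax = 20 and HC = 4 streams, the extra history is
+    about 8192 * 20 * 4 * 4 * 4 bytes ≈ 10 MiB on top of the other saved tensors.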
+ """ + + @staticmethod + @ensure_contiguous + def forward( # type: ignore[override] + ctx: Any, + x: torch.Tensor, # [..., HC, C] bf16/fp16 (or fp32 if allow_fp32) + phi: torch.Tensor, # [HC*C, M] + b: torch.Tensor, # [M] + alpha_pre: torch.Tensor, # scalar + alpha_post: torch.Tensor, # scalar + alpha_res: torch.Tensor, # scalar + allow_fp32: bool, + tmax: int, + rms_eps: float, + pre_eps: float, + sinkhorn_eps: float, + post_mult: float, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if allow_fp32: + assert x.dtype in ( + torch.bfloat16, + torch.float16, + torch.float32, + ), "x should be BF16/FP16/FP32 when allow_fp32=True" + else: + assert x.dtype in (torch.bfloat16, torch.float16), "x should be BF16/FP16 (set allow_fp32=True for FP32)" + # Store original shape for restoring at the end + x_shape = x.shape + x_flat, _ = _flatten_tokens(x) + N, HC, C = x_flat.shape + K = HC * C + x_mat = x_flat.view(-1, K) + + assert phi.dim() == 2 and phi.shape[0] == K, f"phi must be [HC*C, M], got {tuple(phi.shape)}" + M = int(phi.shape[1]) + assert b.shape == (M,), f"b must be [M], got {tuple(b.shape)}" + + # (1) fused coeff matmul + norm + mix, invr = mhc_mm_norm_fwd(x_mat, phi, eps=float(rms_eps)) + + # (2) split + sigmoid + sinkhorn + need_hist = any(ctx.needs_input_grad) + if need_hist: + h_pre, h_post, h_res, hist = mhc_split_sinkhorn_fwd( + mix, + b, + alpha_pre, + alpha_post, + alpha_res, + tmax=int(tmax), + pre_eps=float(pre_eps), + sinkhorn_eps=float(sinkhorn_eps), + post_mult=float(post_mult), + return_hist=True, + ) + else: + h_pre, h_post, h_res = mhc_split_sinkhorn_fwd( + mix, + b, + alpha_pre, + alpha_post, + alpha_res, + tmax=int(tmax), + pre_eps=float(pre_eps), + sinkhorn_eps=float(sinkhorn_eps), + post_mult=float(post_mult), + ) + hist = None + + # Save for backward + if hist is not None: + ctx.save_for_backward(x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res, hist) + else: + ctx.save_for_backward(x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res) + ctx.meta = ( + x_shape, + HC, + C, + int(tmax), + float(sinkhorn_eps), + float(post_mult), + hist is not None, + ) + + # Reshape to original leading dims + outer = x_shape[:-2] + return ( + h_pre.view(*outer, HC), + h_post.view(*outer, HC), + h_res.view(*outer, HC, HC), + ) + + @staticmethod + @ensure_contiguous + def backward( + ctx: Any, + grad_h_pre: torch.Tensor | None, + grad_h_post: torch.Tensor | None, + grad_h_res: torch.Tensor | None, + ): + saved = ctx.saved_tensors + x_shape, HC, C, tmax, sinkhorn_eps, post_mult, has_hist = ctx.meta + if has_hist: + x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res, hist = saved + else: + x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res = saved + hist = None + N = x_mat.shape[0] + M = mix.shape[1] + assert M == HC * HC + 2 * HC + + need_pre = grad_h_pre is not None + need_post = grad_h_post is not None + need_res = grad_h_res is not None + + # flatten grads (None -> zeros) + if need_pre: + gh_pre = grad_h_pre.view(-1, HC).to(torch.float32) + else: + gh_pre = torch.zeros((N, HC), device=mix.device, dtype=torch.float32) + if need_post: + gh_post = grad_h_post.view(-1, HC).to(torch.float32) + else: + gh_post = torch.zeros((N, HC), device=mix.device, dtype=torch.float32) + if need_res: + gh_res = grad_h_res.view(-1, HC, HC).to(torch.float32) + else: + gh_res = torch.zeros((N, HC, HC), device=mix.device, dtype=torch.float32) + + # --- Sinkhorn backward -> grad logits for residual matrix + if need_res: + grad_res_logits = mhc_sinkhorn_bwd( + mix, + b, 
+ alpha_res, + gh_res, + tmax=tmax, + sinkhorn_eps=sinkhorn_eps, + hist=hist, + ) # [N, HC, HC] fp32 + else: + grad_res_logits = gh_res + + # --- Pre/post derivatives (sigmoid) + mix_pre = mix[:, :HC] + mix_post = mix[:, HC : 2 * HC] + mix_res = mix[:, 2 * HC :] + + b_pre = b[:HC] + b_post = b[HC : 2 * HC] + if need_pre: + pre_logits = mix_pre * alpha_pre + b_pre + pre_sig = torch.sigmoid(pre_logits) + grad_pre_logits = gh_pre * (pre_sig * (1.0 - pre_sig)) # [N,HC] + else: + grad_pre_logits = gh_pre + + if need_post: + post_logits = mix_post * alpha_post + b_post + post_sig = torch.sigmoid(post_logits) + grad_post_logits = gh_post * (post_mult * post_sig * (1.0 - post_sig)) # [N,HC] + else: + grad_post_logits = gh_post + + grad_res_logits_flat = grad_res_logits.reshape(N, HC * HC) + + # --- Grad w.r.t mix + grad_mix = torch.empty_like(mix) + grad_mix[:, :HC] = grad_pre_logits * alpha_pre + grad_mix[:, HC : 2 * HC] = grad_post_logits * alpha_post + grad_mix[:, 2 * HC :] = grad_res_logits_flat * alpha_res + + # --- Grad w.r.t b + grad_b = torch.zeros_like(b, dtype=torch.float32) + if need_pre: + grad_b[:HC] = grad_pre_logits.sum(dim=0) + if need_post: + grad_b[HC : 2 * HC] = grad_post_logits.sum(dim=0) + if need_res: + grad_b[2 * HC :] = grad_res_logits_flat.sum(dim=0) + + # --- Grad w.r.t alphas + if need_pre: + grad_alpha_pre = (grad_pre_logits * mix_pre).sum() + else: + grad_alpha_pre = torch.zeros((), device=mix.device, dtype=torch.float32) + if need_post: + grad_alpha_post = (grad_post_logits * mix_post).sum() + else: + grad_alpha_post = torch.zeros((), device=mix.device, dtype=torch.float32) + if need_res: + grad_alpha_res = (grad_res_logits_flat * mix_res).sum() + else: + grad_alpha_res = torch.zeros((), device=mix.device, dtype=torch.float32) + + # --- Grad w.r.t x and phi via fused mm+norm backward + grad_x_mat, grad_phi = mhc_mm_norm_bwd( + x_mat, + phi, + mix, + invr, + grad_mix, + ) + + # Reshape to original shape + grad_x = grad_x_mat.view(x_shape) + + # Return grads for each forward input + return ( + grad_x, # x + grad_phi, # phi + grad_b, # b + grad_alpha_pre, # alpha_pre + grad_alpha_post, # alpha_post + grad_alpha_res, # alpha_res + None, # allow_fp32 + None, # tmax + None, # rms_eps + None, # pre_eps + None, # sinkhorn_eps + None, # post_mult + ) + + +class LigerMHCPreFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx: Any, x: torch.Tensor, h_pre: torch.Tensor) -> torch.Tensor: + x_shape = x.shape + x_flat, _ = _flatten_tokens(x) + h_pre_flat = h_pre.view(-1, x_flat.shape[1]).to(torch.float32) + out = mhc_pre_fwd(x_flat, h_pre_flat) # [N,C] fp32 + ctx.save_for_backward(x_flat, h_pre_flat) + ctx.x_shape = x_shape + out = out.to(x_flat.dtype) + return out.view(*x_shape[:-2], out.shape[-1]) + + @staticmethod + @ensure_contiguous + def backward(ctx: Any, grad_out: torch.Tensor): + x_flat, h_pre_flat = ctx.saved_tensors + x_shape = ctx.x_shape + N, HC, C = x_flat.shape + go = grad_out.view(-1, C).to(torch.float32) + grad_x, grad_h = mhc_pre_bwd(x_flat, h_pre_flat, go) + grad_x = grad_x.to(x_flat.dtype) + return grad_x.view(*x_shape), grad_h.view(*x_shape[:-1]) + + +class LigerMHCPostResFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward( + ctx: Any, x: torch.Tensor, f_out: torch.Tensor, h_post: torch.Tensor, h_res: torch.Tensor + ) -> torch.Tensor: + x_shape = x.shape + x_flat, _ = _flatten_tokens(x) + N, HC, C = x_flat.shape + f_flat = f_out.view(-1, C) + h_post_flat = h_post.view(-1, 
HC).to(torch.float32) + h_res_flat = h_res.view(-1, HC, HC).to(torch.float32) + out = mhc_post_res_fwd(x_flat, f_flat, h_post_flat, h_res_flat) # [N,HC,C] fp32 + ctx.save_for_backward(x_flat, f_flat, h_post_flat, h_res_flat) + ctx.x_shape = x_shape + out = out.to(x_flat.dtype) + return out.view(*x_shape) + + @staticmethod + @ensure_contiguous + def backward(ctx: Any, grad_out: torch.Tensor): + x_flat, f_flat, h_post_flat, h_res_flat = ctx.saved_tensors + x_shape = ctx.x_shape + N, HC, C = x_flat.shape + go = grad_out.view(-1, HC, C).to(torch.float32) + + grad_x, grad_f, grad_hpost, grad_hres = mhc_post_res_bwd(x_flat, f_flat, h_post_flat, h_res_flat, go) + + outer = x_shape[:-2] + return ( + grad_x.to(x_flat.dtype).view(*x_shape), + grad_f.to(f_flat.dtype).view(*outer, C), + grad_hpost.view(*outer, HC), + grad_hres.view(*outer, HC, HC), + ) diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 49e045208..490da1d70 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -14,6 +14,7 @@ from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401 from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401 from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401 +from liger_kernel.transformers.mhc import LigerMHC # noqa: F401 from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401 from liger_kernel.transformers.poly_norm import LigerPolyNorm # noqa: F401 from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401 @@ -171,6 +172,7 @@ def __getattr__(name: str): "LigerTiledSwiGLUMLP", "LigerTVDLoss", "LigerKLDIVLoss", + "LigerMHC", "LigerMultiTokenAttention", "LigerSoftmax", "LigerSparsemax", diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 48437adad..60907bab2 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -14,6 +14,9 @@ from liger_kernel.ops import LigerJSDFunction from liger_kernel.ops import LigerKLDivLossFunction from liger_kernel.ops import LigerLayerNormFunction +from liger_kernel.ops import LigerMHCCoeffsFunction +from liger_kernel.ops import LigerMHCPostResFunction +from liger_kernel.ops import LigerMHCPreFunction from liger_kernel.ops import LigerMultiTokenAttentionFunction from liger_kernel.ops import LigerPolyNormFunction from liger_kernel.ops import LigerQwen2VLMRopeFunction @@ -299,3 +302,100 @@ def liger_softmax(x): def liger_dyt(x, alpha, gamma, beta): return LigerDyTFunction.apply(x, alpha, gamma, beta) + + +def liger_mhc_coeffs( + x, + phi, + b, + alpha_pre, + alpha_post, + alpha_res, + *, + allow_fp32: bool = False, + tmax: int = 20, + rms_eps: float = 1e-6, + pre_eps: float = 0.0, + sinkhorn_eps: float = 1e-6, + post_mult: float = 2.0, +): + # Convert config scalars to Python types so they are not included in the + # autograd computation graph (they are not learnable parameters). 
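+    # Shape reference (see the LigerMHC docstring): x is [..., HC, C],
+    # phi is [HC*C, HC*HC + 2*HC], b is [HC*HC + 2*HC]; the returned
+    # coefficients are h_pre [..., HC], h_post [..., HC], h_res [..., HC, HC].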
+ return LigerMHCCoeffsFunction.apply( + x, + phi, + b, + alpha_pre, + alpha_post, + alpha_res, + allow_fp32, + int(tmax), + float(rms_eps), + float(pre_eps), + float(sinkhorn_eps), + float(post_mult), + ) + + +def liger_mhc_pre(x, h_pre): + return LigerMHCPreFunction.apply(x, h_pre) + + +def liger_mhc_post_res(x, f_out, h_post, h_res): + return LigerMHCPostResFunction.apply(x, f_out, h_post, h_res) + + +def liger_mhc_apply(x, f_out, h_pre, h_post, h_res, *, return_x_in: bool = False): + x_in = liger_mhc_pre(x, h_pre) + x_out = liger_mhc_post_res(x, f_out, h_post, h_res) + if return_x_in: + return x_out, x_in + return x_out + + +def liger_mhc_forward( + x, + layer, + phi, + b, + alpha_pre, + alpha_post, + alpha_res, + *, + allow_fp32=False, + tmax=20, + rms_eps=1e-6, + pre_eps=0.0, + sinkhorn_eps=1e-6, + post_mult=2.0, + return_coeffs=False, +): + """High-level helper: compute coeffs, apply pre, run layer, then apply post+res.""" + h_pre, h_post, h_res = liger_mhc_coeffs( + x, + phi, + b, + alpha_pre, + alpha_post, + alpha_res, + allow_fp32=allow_fp32, + tmax=tmax, + rms_eps=rms_eps, + pre_eps=pre_eps, + sinkhorn_eps=sinkhorn_eps, + post_mult=post_mult, + ) + x_in = liger_mhc_pre(x, h_pre) + layer_dtype = x_in.dtype + if hasattr(layer, "parameters"): + try: + layer_dtype = next(layer.parameters()).dtype + except StopIteration: + layer_dtype = x_in.dtype + if x_in.dtype != layer_dtype: + x_in = x_in.to(layer_dtype) + f_out = layer(x_in) + x_out = liger_mhc_post_res(x, f_out, h_post, h_res) + if return_coeffs: + return x_out, (h_pre, h_post, h_res) + return x_out diff --git a/src/liger_kernel/transformers/mhc.py b/src/liger_kernel/transformers/mhc.py new file mode 100644 index 000000000..30459dfbe --- /dev/null +++ b/src/liger_kernel/transformers/mhc.py @@ -0,0 +1,162 @@ +import warnings + +import torch +import torch.nn as nn + +from liger_kernel.transformers.functional import liger_mhc_coeffs +from liger_kernel.transformers.functional import liger_mhc_post_res +from liger_kernel.transformers.functional import liger_mhc_pre + + +class LigerMHC(nn.Module): + """ + Manifold-Constrained Hyper-Connections (mHC) wrapper. + + Wraps an arbitrary layer ``F: [..., C] -> [..., C]`` with multiple residual + streams, following the mHC architecture (arXiv:2512.24880). The input is a + multi-stream tensor of shape ``[..., HC, C]`` where ``HC`` is the number of + residual streams. + + The forward pass performs: + + 1. **Coefficients** -- Compute data-dependent routing coefficients + (``h_pre``, ``h_post``, ``h_res``) via a fused matmul + RMS + normalization + Sinkhorn-Knopp iterations. + 2. **Pre-aggregate** -- ``x_in = sum_i h_pre[i] * x[i]`` + (shape: ``[..., C]``) + 3. **Layer** -- ``f_out = layer(x_in)`` (shape: ``[..., C]``) + 4. **Post + residual** -- + ``x_out[o] = sum_i h_res[o,i] * x[i] + h_post[o] * f_out`` + (shape: ``[..., HC, C]``) + + Args: + layer: The module applied to the aggregated single-stream input. + Must accept ``[..., C]`` and return ``[..., C]``. Common choices + include ``nn.Linear``, attention layers, or MLP blocks. + hc: Number of residual streams (called *n* in the original paper). + Recommended range: [2, 16]. Larger values increase register + pressure and Triton compile time. + c: Per-stream channel dimension. + tmax: Maximum Sinkhorn-Knopp iterations for doubly stochastic + normalization of ``h_res``. Default: 20. + rms_eps: Epsilon for RMS normalization of the projection. + Default: 1e-6. + pre_eps: Additive epsilon for ``h_pre`` after sigmoid. Default: 0.0. 
+ sinkhorn_eps: Epsilon added during Sinkhorn normalization. + Default: 1e-6. + post_mult: Scaling factor for ``h_post`` after sigmoid. Default: 2.0. + phi_dtype: Dtype for the projection matrix ``phi``. Using float16 or + bfloat16 enables Tensor Core acceleration. Default: torch.float16. + allow_fp32: If True, accept FP32 input tensors. Note that FP32 mode + does **not** use Tensor Cores and will be slower. Default: False. + + Learnable Parameters: + - **phi** ``[HC*C, HC*HC + 2*HC]`` -- Projection matrix for computing + routing coefficients from flattened stream tokens. + - **b** ``[HC*HC + 2*HC]`` -- Bias for routing logits (float32). + - **alpha_pre** (scalar) -- Scales pre-routing logits before sigmoid. + - **alpha_post** (scalar) -- Scales post-routing logits before sigmoid. + - **alpha_res** (scalar) -- Scales residual logits before Sinkhorn. + + Example:: + + import torch + import torch.nn as nn + from liger_kernel.transformers import LigerMHC + + # Wrap a linear layer with 4 residual streams of dimension 256 + layer = nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16) + mhc = LigerMHC(layer, hc=4, c=256, phi_dtype=torch.bfloat16).cuda() + + # Input: [batch, seq_len, num_streams, channels] + x = torch.randn(2, 128, 4, 256, device="cuda", dtype=torch.bfloat16) + out = mhc(x) # shape: [2, 128, 4, 256] + + # In a transformer block (pseudocode): + # x = mhc_attn(x) # attention wrapped in LigerMHC + # x = mhc_ffn(x) # FFN wrapped in LigerMHC + """ + + def __init__( + self, + layer: nn.Module, + *, + hc: int, + c: int, + tmax: int = 20, + rms_eps: float = 1e-6, + pre_eps: float = 0.0, + sinkhorn_eps: float = 1e-6, + post_mult: float = 2.0, + phi_dtype: torch.dtype = torch.float16, + allow_fp32: bool = False, + ): + super().__init__() + self.layer = layer + # hc: number of residual streams (n in the paper) + self.hc = int(hc) + self.c = int(c) + + if hc > 16: + warnings.warn( + f"hc={hc} exceeds recommended range [2, 16]. " + "Large values may cause register pressure and increased compile time.", + stacklevel=2, + ) + self.tmax = int(tmax) + self.rms_eps = float(rms_eps) + self.pre_eps = float(pre_eps) + self.sinkhorn_eps = float(sinkhorn_eps) + self.post_mult = float(post_mult) + self.allow_fp32 = bool(allow_fp32) + + m = hc * hc + 2 * hc + k = hc * c + + try: + layer_device = next(self.layer.parameters()).device + except StopIteration: + layer_device = torch.device("cpu") + + # Note: for best speed, keep phi in BF16/FP16 to enable tensor-core matmul in Triton. 
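+        # Shape check for the docstring example (hc=4, c=256): phi is [k, m] = [1024, 24].
+        # The 0.02 init scale mirrors the test setup and keeps the initial routing logits
+        # near zero, so h_res starts out roughly uniform after the Sinkhorn normalization.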
+ self.phi = nn.Parameter(torch.randn(k, m, dtype=phi_dtype, device=layer_device) * 0.02) + self.b = nn.Parameter(torch.zeros(m, dtype=torch.float32, device=layer_device)) + self.alpha_pre = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=layer_device)) + self.alpha_post = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=layer_device)) + self.alpha_res = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=layer_device)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: [..., HC, C] (BF16/FP16 recommended; FP32 allowed if allow_fp32=True) + returns: [..., HC, C] + """ + if x.shape[-2] != self.hc or x.shape[-1] != self.c: + raise ValueError(f"Expected x.shape[-2:]=[{self.hc}, {self.c}], got {list(x.shape[-2:])}") + + h_pre, h_post, h_res = liger_mhc_coeffs( + x, + self.phi, + self.b, + self.alpha_pre, + self.alpha_post, + self.alpha_res, + allow_fp32=self.allow_fp32, + tmax=self.tmax, + rms_eps=self.rms_eps, + pre_eps=self.pre_eps, + sinkhorn_eps=self.sinkhorn_eps, + post_mult=self.post_mult, + ) + x_in = liger_mhc_pre(x, h_pre) # [..., C] + layer_dtype = x_in.dtype + for param in self.layer.parameters(recurse=True): + layer_dtype = param.dtype + break + if x_in.dtype != layer_dtype: + x_in = x_in.to(layer_dtype) + f_out = self.layer(x_in) # [..., C] + x_out = liger_mhc_post_res(x, f_out, h_post, h_res) # [..., HC, C] + return x_out + + def extra_repr(self) -> str: + return f"hc={self.hc}, c={self.c}, tmax={self.tmax}" diff --git a/test/transformers/test_mhc.py b/test/transformers/test_mhc.py new file mode 100644 index 000000000..6c8b273f6 --- /dev/null +++ b/test/transformers/test_mhc.py @@ -0,0 +1,522 @@ +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from test.utils import assert_verbose_allclose +from test.utils import infer_device +from test.utils import set_seed +from test.utils import supports_bfloat16 + +from liger_kernel.transformers.functional import liger_mhc_coeffs +from liger_kernel.transformers.functional import liger_mhc_post_res +from liger_kernel.transformers.functional import liger_mhc_pre +from liger_kernel.transformers.mhc import LigerMHC + +device = infer_device() + +MHC_SHAPES = [ + (2, 4, 2, 32), + (1, 8, 4, 64), +] + +MHC_DTYPE_TOLS = [ + (torch.float16, 8e-3, 1.5e-2, 2e-2), + pytest.param( + torch.bfloat16, + 1.5e-2, + 2.5e-2, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), +] + +MHC_COEFFS_DTYPE_TOLS = MHC_DTYPE_TOLS + [ + (torch.float32, 5e-4, 1e-3, 2e-3), +] + + +def mhc_sinkhorn_ref(logits: torch.Tensor, *, tmax: int, eps: float) -> torch.Tensor: + """ + logits: [N, HC, HC] + """ + mat = torch.softmax(logits, dim=-1) + eps + mat = mat / (mat.sum(dim=-2, keepdim=True) + eps) + for _ in range(tmax - 1): + mat = mat / (mat.sum(dim=-1, keepdim=True) + eps) + mat = mat / (mat.sum(dim=-2, keepdim=True) + eps) + return mat + + +def mhc_coeffs_ref( + x: torch.Tensor, + phi: torch.Tensor, + b: torch.Tensor, + alpha_pre: torch.Tensor, + alpha_post: torch.Tensor, + alpha_res: torch.Tensor, + *, + tmax: int, + rms_eps: float, + pre_eps: float, + sinkhorn_eps: float, + post_mult: float, +): + x_flat = x.contiguous().view(-1, x.shape[-2], x.shape[-1]).float() + n, hc, c = x_flat.shape + k = hc * c + x_mat = x_flat.view(n, k) + invr = torch.rsqrt(x_mat.pow(2).mean(dim=-1, keepdim=True) + rms_eps) + mix = (x_mat @ phi.float()) * invr + + pre_logits = mix[:, :hc] * alpha_pre + b[:hc] + post_logits = mix[:, hc : 2 * hc] * alpha_post + b[hc : 2 * 
hc] + res_logits = mix[:, 2 * hc :].view(n, hc, hc) * alpha_res + b[2 * hc :].view(hc, hc) + + h_pre = torch.sigmoid(pre_logits) + pre_eps + h_post = torch.sigmoid(post_logits) * post_mult + h_res = mhc_sinkhorn_ref(res_logits, tmax=tmax, eps=sinkhorn_eps) + + outer = x.shape[:-2] + return ( + h_pre.view(*outer, hc), + h_post.view(*outer, hc), + h_res.view(*outer, hc, hc), + ) + + +@pytest.mark.parametrize("B, T, HC, C", MHC_SHAPES) +@pytest.mark.parametrize("phi_dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype, pre_post_tol, res_tol, grad_tol", MHC_COEFFS_DTYPE_TOLS) +def test_mhc_coeffs_forward_backward(B, T, HC, C, phi_dtype, dtype, pre_post_tol, res_tol, grad_tol): + set_seed(42) + K = HC * C + M = HC * HC + 2 * HC + + allow_fp32 = dtype == torch.float32 + if allow_fp32: + phi_dtype = torch.float32 + + x = torch.randn(B, T, HC, C, device=device, dtype=dtype, requires_grad=True) + phi = (torch.randn(K, M, device=device, dtype=phi_dtype) * 0.02).requires_grad_(True) + b = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=True) + alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + + cfg = dict(tmax=8, rms_eps=1e-6, pre_eps=1e-4, sinkhorn_eps=1e-6, post_mult=2.0) + + h_pre, h_post, h_res = liger_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res, allow_fp32=allow_fp32, **cfg) + + loss = h_pre.square().mean() + h_post.square().mean() + h_res.square().mean() + loss.backward() + + grads_triton = ( + x.grad.detach().float().clone(), + phi.grad.detach().float().clone(), + b.grad.detach().float().clone(), + alpha_pre.grad.detach().float().clone(), + alpha_post.grad.detach().float().clone(), + alpha_res.grad.detach().float().clone(), + ) + + x2 = x.detach().clone().requires_grad_(True) + phi2 = phi.detach().clone().requires_grad_(True) + b2 = b.detach().clone().requires_grad_(True) + ap2 = alpha_pre.detach().clone().requires_grad_(True) + apo2 = alpha_post.detach().clone().requires_grad_(True) + ar2 = alpha_res.detach().clone().requires_grad_(True) + + rh_pre, rh_post, rh_res = mhc_coeffs_ref(x2, phi2, b2, ap2, apo2, ar2, **cfg) + rloss = rh_pre.square().mean() + rh_post.square().mean() + rh_res.square().mean() + rloss.backward() + + grads_ref = ( + x2.grad.detach().float(), + phi2.grad.detach().float(), + b2.grad.detach().float(), + ap2.grad.detach().float(), + apo2.grad.detach().float(), + ar2.grad.detach().float(), + ) + + assert_verbose_allclose(h_pre.float(), rh_pre.float(), rtol=pre_post_tol, atol=pre_post_tol) + assert_verbose_allclose(h_post.float(), rh_post.float(), rtol=pre_post_tol, atol=pre_post_tol) + assert_verbose_allclose(h_res.float(), rh_res.float(), rtol=res_tol, atol=res_tol) + + for gt, gr in zip(grads_triton, grads_ref): + assert_verbose_allclose(gt, gr, rtol=grad_tol, atol=grad_tol) + + +def test_mhc_coeffs_disallow_fp32(): + B, T, HC, C = 1, 2, 2, 8 + K = HC * C + M = HC * HC + 2 * HC + + x = torch.randn(B, T, HC, C, device=device, dtype=torch.float32) + phi = torch.randn(K, M, device=device, dtype=torch.float32) + b = torch.zeros(M, device=device, dtype=torch.float32) + alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32) + alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32) + alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32) + + with pytest.raises(AssertionError): + _ = 
liger_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res) + + +@pytest.mark.parametrize("B, T, HC, C", MHC_SHAPES) +@pytest.mark.parametrize( + "use_pre,use_post,use_res", + [ + (True, False, False), + (False, True, False), + (False, False, True), + ], +) +def test_mhc_coeffs_backward_allows_unused_outputs(B, T, HC, C, use_pre, use_post, use_res): + set_seed(42) + K = HC * C + M = HC * HC + 2 * HC + + x = torch.randn(B, T, HC, C, device=device, dtype=torch.float16, requires_grad=True) + phi = (torch.randn(K, M, device=device, dtype=torch.float16) * 0.02).requires_grad_(True) + b = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=True) + alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + + cfg = dict(tmax=4, rms_eps=1e-6, pre_eps=1e-4, sinkhorn_eps=1e-6, post_mult=2.0) + + h_pre, h_post, h_res = liger_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res, **cfg) + + loss = torch.zeros((), device=device) + if use_pre: + loss = loss + h_pre.square().mean() + if use_post: + loss = loss + h_post.square().mean() + if use_res: + loss = loss + h_res.square().mean() + loss.backward() + + for tensor in (x, phi, b, alpha_pre, alpha_post, alpha_res): + assert tensor.grad is not None + + +@pytest.mark.parametrize("B, T, HC, C", MHC_SHAPES) +@pytest.mark.parametrize("dtype, pre_post_tol, res_tol, grad_tol", MHC_DTYPE_TOLS) +def test_mhc_pre_and_post_res_match_reference(B, T, HC, C, dtype, pre_post_tol, res_tol, grad_tol): + set_seed(42) + + # Liger path + x1 = torch.randn(B, T, HC, C, device=device, dtype=dtype, requires_grad=True) + h_pre1 = torch.rand(B, T, HC, device=device, dtype=torch.float32, requires_grad=True) + h_post1 = torch.rand(B, T, HC, device=device, dtype=torch.float32, requires_grad=True) + h_res1 = torch.rand(B, T, HC, HC, device=device, dtype=torch.float32, requires_grad=True) + f_out1 = torch.randn(B, T, C, device=device, dtype=dtype, requires_grad=True) + + x_in = liger_mhc_pre(x1, h_pre1) + x_out = liger_mhc_post_res(x1, f_out1, h_post1, h_res1) + + # Reference path (clone inputs for independent computation graph) + x2 = x1.detach().clone().requires_grad_(True) + h_pre2 = h_pre1.detach().clone().requires_grad_(True) + h_post2 = h_post1.detach().clone().requires_grad_(True) + h_res2 = h_res1.detach().clone().requires_grad_(True) + f_out2 = f_out1.detach().clone().requires_grad_(True) + + x_in_ref = (x2.float() * h_pre2.unsqueeze(-1)).sum(dim=-2) + x_out_ref = torch.einsum("...oi,...ic->...oc", h_res2, x2.float()) + h_post2.unsqueeze( + -1 + ) * f_out2.float().unsqueeze(-2) + + # Forward check + assert_verbose_allclose(x_in.float(), x_in_ref, rtol=pre_post_tol, atol=pre_post_tol) + assert_verbose_allclose(x_out.float(), x_out_ref, rtol=res_tol, atol=res_tol) + + # Backward check + loss = x_in.square().mean() + x_out.square().mean() + loss.backward() + + loss_ref = x_in_ref.square().mean() + x_out_ref.square().mean() + loss_ref.backward() + + assert_verbose_allclose(x1.grad.float(), x2.grad.float(), rtol=grad_tol, atol=grad_tol) + assert_verbose_allclose(h_pre1.grad.float(), h_pre2.grad.float(), rtol=grad_tol, atol=grad_tol) + assert_verbose_allclose(h_post1.grad.float(), h_post2.grad.float(), rtol=grad_tol, atol=grad_tol) + assert_verbose_allclose(h_res1.grad.float(), h_res2.grad.float(), rtol=grad_tol, atol=grad_tol) + 
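+    # The h_post / h_res gradient checks above also exercise the tl.atomic_add
+    # accumulation across C blocks in the post_res backward kernel.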
assert_verbose_allclose(f_out1.grad.float(), f_out2.grad.float(), rtol=grad_tol, atol=grad_tol) + + +@pytest.mark.parametrize("B, T, HC, C", MHC_SHAPES) +@pytest.mark.parametrize("dtype, pre_post_tol, res_tol, grad_tol", MHC_DTYPE_TOLS) +def test_liger_mhc_functional(B, T, HC, C, dtype, pre_post_tol, res_tol, grad_tol): + set_seed(42) + K = HC * C + M = HC * HC + 2 * HC + + x = torch.randn(B, T, HC, C, device=device, dtype=dtype, requires_grad=True) + phi = (torch.randn(K, M, device=device, dtype=dtype) * 0.02).requires_grad_(True) + b = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=True) + alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + + cfg = dict(tmax=4, rms_eps=1e-6, pre_eps=1e-4, sinkhorn_eps=1e-6, post_mult=2.0) + + h_pre, h_post, h_res = liger_mhc_coeffs(x, phi, b, alpha_pre, alpha_post, alpha_res, **cfg) + rh_pre, rh_post, rh_res = mhc_coeffs_ref(x, phi, b, alpha_pre, alpha_post, alpha_res, **cfg) + + assert_verbose_allclose(h_pre.float(), rh_pre.float(), rtol=pre_post_tol, atol=pre_post_tol, extra_info="[h_pre]") + assert_verbose_allclose( + h_post.float(), rh_post.float(), rtol=pre_post_tol, atol=pre_post_tol, extra_info="[h_post]" + ) + assert_verbose_allclose(h_res.float(), rh_res.float(), rtol=res_tol, atol=res_tol, extra_info="[h_res]") + + loss = h_pre.square().mean() + h_post.square().mean() + h_res.square().mean() + loss.backward() + + x2 = x.detach().clone().requires_grad_(True) + phi2 = phi.detach().clone().requires_grad_(True) + b2 = b.detach().clone().requires_grad_(True) + ap2 = alpha_pre.detach().clone().requires_grad_(True) + apo2 = alpha_post.detach().clone().requires_grad_(True) + ar2 = alpha_res.detach().clone().requires_grad_(True) + rh_pre2, rh_post2, rh_res2 = mhc_coeffs_ref(x2, phi2, b2, ap2, apo2, ar2, **cfg) + rloss = rh_pre2.square().mean() + rh_post2.square().mean() + rh_res2.square().mean() + rloss.backward() + + assert_verbose_allclose(x.grad.float(), x2.grad.float(), rtol=grad_tol, atol=grad_tol, extra_info="[x.grad]") + assert_verbose_allclose(phi.grad.float(), phi2.grad.float(), rtol=grad_tol, atol=grad_tol, extra_info="[phi.grad]") + assert_verbose_allclose(b.grad.float(), b2.grad.float(), rtol=grad_tol, atol=grad_tol, extra_info="[b.grad]") + assert_verbose_allclose( + alpha_pre.grad.float(), ap2.grad.float(), rtol=grad_tol, atol=grad_tol, extra_info="[alpha_pre]" + ) + assert_verbose_allclose( + alpha_post.grad.float(), apo2.grad.float(), rtol=grad_tol, atol=grad_tol, extra_info="[alpha_post]" + ) + assert_verbose_allclose( + alpha_res.grad.float(), ar2.grad.float(), rtol=grad_tol, atol=grad_tol, extra_info="[alpha_res]" + ) + + x3 = x.detach().clone().requires_grad_(True) + h_pre3 = h_pre.detach().clone().requires_grad_(True) + h_post3 = h_post.detach().clone().requires_grad_(True) + h_res3 = h_res.detach().clone().requires_grad_(True) + f_out = torch.randn(B, T, C, device=device, dtype=dtype, requires_grad=True) + + x_in = liger_mhc_pre(x3, h_pre3) + x_out = liger_mhc_post_res(x3, f_out, h_post3, h_res3) + + x_in_ref = (x3.float() * h_pre3.unsqueeze(-1)).sum(dim=-2) + x_out_ref = torch.einsum("...oi,...ic->...oc", h_res3, x3.float()) + h_post3.unsqueeze( + -1 + ) * f_out.float().unsqueeze(-2) + + assert_verbose_allclose(x_in.float(), x_in_ref, rtol=pre_post_tol, atol=pre_post_tol, extra_info="[x_in]") + 
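+    # x_out covers the fused post+res path: x_out[o] = sum_i h_res[o, i] * x[i] + h_post[o] * f_out.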
assert_verbose_allclose(x_out.float(), x_out_ref, rtol=res_tol, atol=res_tol, extra_info="[x_out]") + + +@pytest.mark.parametrize("B, T, HC, C", MHC_SHAPES) +@pytest.mark.parametrize("dtype, _pre_post_tol, res_tol, grad_tol", MHC_DTYPE_TOLS) +def test_liger_mhc_module(B, T, HC, C, dtype, _pre_post_tol, res_tol, grad_tol): + set_seed(42) + + layer = nn.Linear(C, C, bias=False, device=device, dtype=dtype) + model = LigerMHC( + layer, + hc=HC, + c=C, + tmax=4, + rms_eps=1e-6, + pre_eps=1e-4, + sinkhorn_eps=1e-6, + post_mult=2.0, + phi_dtype=dtype, + ).to(device) + + x_fast = torch.randn(B, T, HC, C, device=device, dtype=dtype, requires_grad=True) + out_fast = model(x_fast) + + x_ref = x_fast.detach().clone().requires_grad_(True) + phi_ref = model.phi.detach().clone().requires_grad_(True) + b_ref = model.b.detach().clone().requires_grad_(True) + ap_ref = model.alpha_pre.detach().clone().requires_grad_(True) + apo_ref = model.alpha_post.detach().clone().requires_grad_(True) + ar_ref = model.alpha_res.detach().clone().requires_grad_(True) + + layer_ref = nn.Linear(C, C, bias=False, device=device, dtype=dtype) + layer_ref.weight.data.copy_(model.layer.weight.data) + + h_pre, h_post, h_res = mhc_coeffs_ref( + x_ref, + phi_ref, + b_ref, + ap_ref, + apo_ref, + ar_ref, + tmax=4, + rms_eps=1e-6, + pre_eps=1e-4, + sinkhorn_eps=1e-6, + post_mult=2.0, + ) + x_in_ref = (x_ref.float() * h_pre.unsqueeze(-1)).sum(dim=-2).to(dtype) + f_out_ref = layer_ref(x_in_ref) + out_ref = torch.einsum("...oi,...ic->...oc", h_res, x_ref.float()) + h_post.unsqueeze( + -1 + ) * f_out_ref.float().unsqueeze(-2) + + assert_verbose_allclose(out_fast.float(), out_ref.float(), rtol=res_tol, atol=res_tol, extra_info="[output]") + + grad = torch.randn_like(out_fast, dtype=torch.float32) + out_fast.backward(grad.to(out_fast.dtype)) + out_ref.backward(grad) + + assert_verbose_allclose( + x_fast.grad.float(), x_ref.grad.float(), rtol=grad_tol, atol=grad_tol, extra_info="[x.grad]" + ) + phi_grad_tol = grad_tol * 4 if dtype == torch.bfloat16 else grad_tol + assert_verbose_allclose( + model.phi.grad.float(), + phi_ref.grad.float(), + rtol=phi_grad_tol, + atol=phi_grad_tol, + extra_info="[phi.grad]", + ) + assert_verbose_allclose( + model.b.grad.float(), b_ref.grad.float(), rtol=grad_tol, atol=grad_tol, extra_info="[b.grad]" + ) + assert_verbose_allclose( + model.alpha_pre.grad.float(), ap_ref.grad.float(), rtol=grad_tol, atol=grad_tol, extra_info="[alpha_pre.grad]" + ) + assert_verbose_allclose( + model.alpha_post.grad.float(), + apo_ref.grad.float(), + rtol=grad_tol, + atol=grad_tol, + extra_info="[alpha_post.grad]", + ) + assert_verbose_allclose( + model.alpha_res.grad.float(), ar_ref.grad.float(), rtol=grad_tol, atol=grad_tol, extra_info="[alpha_res.grad]" + ) + layer_grad_tol = grad_tol * 4 if dtype == torch.bfloat16 else grad_tol + assert_verbose_allclose( + model.layer.weight.grad.float(), + layer_ref.weight.grad.float(), + rtol=layer_grad_tol, + atol=layer_grad_tol, + extra_info="[layer.weight.grad]", + ) + + +class MiniMHCLM(nn.Module): + """Tiny language model using mHC for end-to-end correctness testing.""" + + def __init__(self, *, vocab_size, hc, c, tmax, rms_eps, pre_eps, sinkhorn_eps, post_mult, use_fast, device): + super().__init__() + self.vocab_size = vocab_size + self.hc = hc + self.c = c + self.tmax = tmax + self.rms_eps = rms_eps + self.pre_eps = pre_eps + self.sinkhorn_eps = sinkhorn_eps + self.post_mult = post_mult + self.use_fast = use_fast + self.act_dtype = torch.bfloat16 + + self.embed = nn.Embedding(vocab_size, 
hc * c, device=device) + self.inner = nn.Linear(c, c, bias=False, device=device) + self.head = nn.Linear(hc * c, vocab_size, bias=False, device=device) + + m = hc * hc + 2 * hc + k = hc * c + self.phi = nn.Parameter(torch.randn(k, m, device=device, dtype=self.act_dtype) * 0.02) + self.b = nn.Parameter(torch.zeros(m, device=device, dtype=torch.float32)) + self.alpha_pre = nn.Parameter(torch.tensor(1.0, device=device, dtype=torch.float32)) + self.alpha_post = nn.Parameter(torch.tensor(1.0, device=device, dtype=torch.float32)) + self.alpha_res = nn.Parameter(torch.tensor(1.0, device=device, dtype=torch.float32)) + + def forward(self, input_ids): + x = self.embed(input_ids).to(self.act_dtype) + bsz, seq_len, _ = x.shape + x = x.view(bsz, seq_len, self.hc, self.c) + + cfg = dict( + tmax=self.tmax, + rms_eps=self.rms_eps, + pre_eps=self.pre_eps, + sinkhorn_eps=self.sinkhorn_eps, + post_mult=self.post_mult, + ) + if self.use_fast: + h_pre, h_post, h_res = liger_mhc_coeffs( + x, self.phi, self.b, self.alpha_pre, self.alpha_post, self.alpha_res, **cfg + ) + x_in = liger_mhc_pre(x, h_pre) + f_out = self.inner(x_in.float()) + x_out = liger_mhc_post_res(x, f_out, h_post, h_res) + else: + h_pre, h_post, h_res = mhc_coeffs_ref( + x, self.phi, self.b, self.alpha_pre, self.alpha_post, self.alpha_res, **cfg + ) + x_in = (x.float() * h_pre.unsqueeze(-1)).sum(dim=-2) + f_out = self.inner(x_in) + x_out = torch.einsum("...oi,...ic->...oc", h_res, x.float()) + h_post.unsqueeze(-1) * f_out.unsqueeze(-2) + + x_merge = x_out.float().view(bsz, seq_len, self.hc * self.c) + return self.head(x_merge) + + +@pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU") +@pytest.mark.parametrize( + "vocab_size, hc, c, tmax", + [ + (32, 2, 16, 4), + (64, 4, 32, 8), + ], +) +def test_mhc_mini_lm_output_match(vocab_size, hc, c, tmax): + set_seed(42) + + model_cfg = dict( + vocab_size=vocab_size, hc=hc, c=c, tmax=tmax, rms_eps=1e-6, pre_eps=1e-4, sinkhorn_eps=1e-6, post_mult=2.0 + ) + + model_fast = MiniMHCLM(**model_cfg, use_fast=True, device=device) + model_ref = MiniMHCLM(**model_cfg, use_fast=False, device=device) + model_ref.load_state_dict(model_fast.state_dict()) + + input_ids = torch.randint(0, vocab_size, (2, 8), device=device) + labels = torch.randint(0, vocab_size, (2, 8), device=device) + + logits_fast = model_fast(input_ids) + logits_ref = model_ref(input_ids) + + assert_verbose_allclose(logits_fast.float(), logits_ref.float(), atol=5e-3, rtol=2e-2, extra_info="[logits]") + + loss_fast = F.cross_entropy(logits_fast.view(-1, vocab_size), labels.view(-1)) + loss_ref = F.cross_entropy(logits_ref.view(-1, vocab_size), labels.view(-1)) + + loss_fast.backward() + loss_ref.backward() + + for name in ["phi", "b", "alpha_pre", "alpha_post", "alpha_res"]: + g_fast = getattr(model_fast, name).grad.float() + g_ref = getattr(model_ref, name).grad.float() + assert_verbose_allclose(g_fast, g_ref, atol=5e-2, rtol=5e-2, extra_info=f"[{name}.grad]") + + assert_verbose_allclose( + model_fast.inner.weight.grad.float(), + model_ref.inner.weight.grad.float(), + atol=5e-2, + rtol=5e-2, + extra_info="[inner.weight.grad]", + ) + assert_verbose_allclose( + model_fast.head.weight.grad.float(), + model_ref.head.weight.grad.float(), + atol=5e-2, + rtol=5e-2, + extra_info="[head.weight.grad]", + )
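+
+
+# Usage note: the high-level helpers `liger_mhc_forward` / `liger_mhc_apply` from
+# liger_kernel.transformers.functional are not exercised above. A minimal sketch,
+# assuming a bfloat16-capable device and illustrative sizes (B, T, HC, C = 2, 8, 4, 64):
+#
+#     from liger_kernel.transformers.functional import liger_mhc_forward
+#
+#     B, T, HC, C = 2, 8, 4, 64
+#     layer = nn.Linear(C, C, bias=False, device=device, dtype=torch.bfloat16)
+#     x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16)
+#     phi = torch.randn(HC * C, HC * HC + 2 * HC, device=device, dtype=torch.bfloat16) * 0.02
+#     b = torch.zeros(HC * HC + 2 * HC, device=device, dtype=torch.float32)
+#     alpha = torch.tensor(1.0, device=device, dtype=torch.float32)
+#     out = liger_mhc_forward(x, layer, phi, b, alpha, alpha, alpha)  # [B, T, HC, C]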