From 672c8c8ad42d82d994e35ec2d758d2b83018abec Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Mon, 3 Nov 2025 20:22:04 +0800 Subject: [PATCH 01/27] feat(FLCE): add helion version of fused linear cross entropy --- .../ops/helion/fused_linear_cross_entropy.py | 227 ++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 src/liger_kernel/ops/helion/fused_linear_cross_entropy.py diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py new file mode 100644 index 000000000..562dae132 --- /dev/null +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -0,0 +1,227 @@ +import helion +import helion.language as hl +import torch + + +@helion.kernel(autotune_effort="none") +def fused_linear_cross_entropy_fwd( + x: torch.Tensor, + weight: torch.Tensor, + target: torch.Tensor, + ignore_index: int = -100, + reduction: str = "mean", +) -> torch.Tensor: + """ + Performs matrix multiplication followed by cross entropy loss. + Args: + x: input tensor of shape [BT, H] + weight: weight tensor of shape [V, H] + target: target tensor of shape [BT,] + ignore_index: index to ignore in the target + reduction: reduction to apply to the loss + Returns: + loss: loss tensor of shape [1] + """ + BT, H = x.size() + V = weight.size(0) + block_size_bt = hl.register_block_size(BT) + block_size_h = hl.register_block_size(H) + block_size_v = hl.register_block_size(V) + + logits = torch.empty(BT, V, device=x.device, dtype=torch.float32) + lse = torch.full((BT,), fill_value=-torch.inf, device=x.device, dtype=torch.float32) # DEBUG + nll = torch.zeros(BT, device=x.device, dtype=torch.float32) + neg_target_logits = torch.zeros(BT, device=x.device, dtype=torch.float32) + + for tile_bt in hl.tile(BT, block_size=block_size_bt): + m_i = hl.zeros([tile_bt], dtype=torch.float32) - float("inf") + d_i = hl.zeros([tile_bt], dtype=torch.float32) + nll_tile = hl.zeros([tile_bt], dtype=torch.float32) + # target_indices = target[tile_bt][:, None] # [tile_bt, 1] # ERROR + target_indices = target[tile_bt].unsqueeze(1) # [tile_bt, 1] + for tile_v in hl.tile(V, block_size=block_size_v): + # logits computation + acc = hl.zeros([tile_bt, tile_v], dtype=torch.float32) + for tile_h in hl.tile(H, block_size=block_size_h): + x_tile = x[tile_bt, tile_h] + weight_tile = weight[tile_v, tile_h] + acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32) + + logits[tile_bt, tile_v] = acc # DEBUG + + # online softmax statistics + m_ij = torch.maximum(m_i, torch.amax(acc, dim=-1)) + d_i = d_i * torch.exp(m_i - m_ij) + torch.exp(acc - m_ij[:, None]).sum(dim=-1) + m_i = m_ij + + # offset = tile_v.index[None, :] # [1, tile_v] # ERROR + offset = tile_v.index.unsqueeze(0) # [1, tile_v] + mask = target_indices == offset # [tile_bt, tile_v] + nll_tile += torch.sum(-acc * mask, dim=-1) # [tile_bt] + + # loss computation: -logsoftmax(x_y) = -log(exp(x_y) / sum(exp(x_i))) = -x_y + log(sum(exp(x_i))) + lse_tile = m_i + torch.log(d_i) + lse[tile_bt] = lse_tile + + neg_target_logits[tile_bt] = nll_tile + + nll_tile = nll_tile + lse_tile + nll[tile_bt] = nll_tile + + + if reduction == "mean": + loss = nll.sum() / nll.numel() + elif reduction == "sum": + loss = nll.sum() + else: + loss = nll + + return loss, logits.to(x.dtype), lse, neg_target_logits + + +# class LigerFusedLinearCrossEntropyHelionFunction(torch.autograd.Function): +# @staticmethod +# def forward( +# ctx, +# _input, +# weight, +# target, +# ignore_index=-100, +# reduction="mean", +# ): +# assert _input.ndim == weight.ndim +# loss, grad_input, grad_weight = fused_linear_cross_entropy_fwd_bwd( +# _input, +# weight, +# target, +# ignore_index, +# reduction, +# ) +# ctx.save_for_backward(grad_input, grad_weight) +# return loss + +# @staticmethod +# def backward(ctx, grad_output): +# grad_input, grad_weight = ctx.saved_tensors +# return grad_input * grad_output, grad_weight * grad_output, None, None, None + + +class LigerFusedLinearCrossEntropyHelion(torch.nn.Module): + def __init__(self, ignore_index=-100, reduction="mean"): + super().__init__() + self.ignore_index = ignore_index + self.reduction = reduction + + def forward(self, _input, weight, target): + # return LigerFusedLinearCrossEntropyHelionFunction.apply( + # _input, + # weight, + # target, + # self.ignore_index, + # self.reduction + # ) + return fused_linear_cross_entropy_fwd(_input, weight, target, self.ignore_index, self.reduction) + + +class TorchLMHeadCE(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + ignore_index: int = -100, + reduction: str = "mean", + ): + super().__init__() + self.lm_head = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) + self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) + + def forward(self, x, target): + logits = self.lm_head(x).to(torch.float32) + return self.ce_loss(logits, target) + + +class LigerLMHeadCE(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + ignore_index: int = -100, + reduction: str = "mean", + ): + super().__init__() + self.lm_head = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) + self.flce = LigerFusedLinearCrossEntropyHelion(ignore_index=ignore_index, reduction=reduction) + + def forward(self, x, target): + return self.flce(x, self.lm_head.weight, target) + + +if __name__ == "__main__": + torch.manual_seed(0) + torch.cuda.manual_seed(0) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = True + + device = "cuda" + + batch_size = 2 + seq_len = 1024 + hidden_size = 4096 + vocab_size = 32000 + dtype = torch.float32 + reduction = "none" + ignore_index = -100 + + input = torch.randn(batch_size * seq_len, hidden_size, device=device, requires_grad=True) + weight = torch.randn(vocab_size, hidden_size, device=device, requires_grad=True) + target = torch.randint(0, vocab_size, (batch_size * seq_len,), device=device) + + # Init + ref_lm_head_ce = TorchLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) + liger_lm_head_ce = LigerLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) + + ref_lm_head_ce.lm_head.weight.data = weight.data + liger_lm_head_ce.lm_head.weight.data = weight.data + + ref_input = input.detach().clone().requires_grad_(True) + liger_input = input.detach().clone().requires_grad_(True) + + # Forward pass + ref_loss = ref_lm_head_ce(ref_input, target) + ref_logits = input @ weight.T + liger_loss, liger_logits, liger_lse, liger_neg_target_logits = liger_lm_head_ce(liger_input, target) + + liger_logprobs = torch.nn.functional.log_softmax(liger_logits, dim=-1) + ref_logprobs = torch.nn.functional.log_softmax(ref_logits, dim=-1) + + ref_lse = torch.logsumexp(ref_logits, dim=-1) + ref_neg_target_logits = torch.nn.functional.nll_loss(ref_logits, target, reduction="none") + ref_neg_target_logits2 = torch.masked_select( + ref_logits, mask=target[:, None] == torch.arange(vocab_size, device=ref_logits.device)[None, :] + ) + + + for i in range(5): + print("=" * 30 + f"(i = {i})" + "=" * 30) + print(f"{ref_lse[i]=}") + print(f"{ref_neg_target_logits[i]=}") + print(f"{ref_neg_target_logits[i] + ref_lse[i]=}") + print(f"{ref_loss[i]=}") + print(f"{liger_loss[i]=}") + print("=" * 64) + + torch.testing.assert_close(liger_logprobs, ref_logprobs, rtol=1e-1, atol=1e-1) + torch.testing.assert_close(liger_lse, ref_lse, rtol=1e-1, atol=1e-1) + torch.testing.assert_close(liger_neg_target_logits, ref_neg_target_logits, rtol=1e-1, atol=1e-1) + torch.testing.assert_close(ref_loss, liger_loss, rtol=1e-1, atol=1e-1) + + # Backward pass + + # ref_loss.backward() + # liger_loss.backward() + + # torch.testing.assert_close(ref_input.grad, liger_input.grad) + # torch.testing.assert_close(ref_lm_head_ce.lm_head.weight.grad, liger_lm_head_ce.lm_head.weight.grad) From fc9a40638ed66f7816231ced84ea88b1b59577ef Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Mon, 3 Nov 2025 21:23:07 +0800 Subject: [PATCH 02/27] compute dx Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 562dae132..8402ca0d3 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -32,6 +32,8 @@ def fused_linear_cross_entropy_fwd( lse = torch.full((BT,), fill_value=-torch.inf, device=x.device, dtype=torch.float32) # DEBUG nll = torch.zeros(BT, device=x.device, dtype=torch.float32) neg_target_logits = torch.zeros(BT, device=x.device, dtype=torch.float32) + grad_x = torch.zeros(BT, H, device=x.device, dtype=torch.float32) + for tile_bt in hl.tile(BT, block_size=block_size_bt): m_i = hl.zeros([tile_bt], dtype=torch.float32) - float("inf") @@ -47,6 +49,7 @@ def fused_linear_cross_entropy_fwd( weight_tile = weight[tile_v, tile_h] acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32) + logits[tile_bt, tile_v] = acc # DEBUG # online softmax statistics @@ -68,6 +71,27 @@ def fused_linear_cross_entropy_fwd( nll_tile = nll_tile + lse_tile nll[tile_bt] = nll_tile + # gradients computation + for tile_v in hl.tile(V, block_size=block_size_v): + # Restore logits + # acc = hl.zeros([tile_bt, tile_v], dtype=torch.float32) + # for tile_h in hl.tile(H, block_size=block_size_h): + # x_tile = x[tile_bt, tile_h] + # weight_tile = weight[tile_v, tile_h] + # acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32) + + logits_tile = logits[tile_bt, tile_v] + + # softmax(x_i) = exp(x_i) / sum(exp(x_i)) + # = exp(x_i) / log(exp(sum(x_i))) + # = exp(x_i) / lse = exp(x_i - lse) + grad_logits_tile = torch.exp(logits_tile - lse_tile[:, None]) + + # grad_x = grad_logits @ weight + for tile_h in hl.tile(H, block_size=block_size_h): + weight_tile = weight[tile_v, tile_h] + partial_grad_x = hl.dot(grad_logits_tile, weight_tile, out_dtype=torch.float32) + hl.atomic_add(grad_x, [tile_bt, tile_h], partial_grad_x) if reduction == "mean": loss = nll.sum() / nll.numel() @@ -76,7 +100,9 @@ def fused_linear_cross_entropy_fwd( else: loss = nll - return loss, logits.to(x.dtype), lse, neg_target_logits + return loss, logits.to(x.dtype), lse, neg_target_logits, grad_x + + # class LigerFusedLinearCrossEntropyHelionFunction(torch.autograd.Function): @@ -192,7 +218,7 @@ def forward(self, x, target): # Forward pass ref_loss = ref_lm_head_ce(ref_input, target) ref_logits = input @ weight.T - liger_loss, liger_logits, liger_lse, liger_neg_target_logits = liger_lm_head_ce(liger_input, target) + liger_loss, liger_logits, liger_lse, liger_neg_target_logits, liger_grad_x = liger_lm_head_ce(liger_input, target) liger_logprobs = torch.nn.functional.log_softmax(liger_logits, dim=-1) ref_logprobs = torch.nn.functional.log_softmax(ref_logits, dim=-1) @@ -218,7 +244,10 @@ def forward(self, x, target): torch.testing.assert_close(liger_neg_target_logits, ref_neg_target_logits, rtol=1e-1, atol=1e-1) torch.testing.assert_close(ref_loss, liger_loss, rtol=1e-1, atol=1e-1) + # Backward pass + ref_loss.backward() + torch.testing.assert_close(liger_grad_x, ref_input.grad, rtol=1e-1, atol=1e-1) # ref_loss.backward() # liger_loss.backward() From 5c816488438ecc09a54a77165a3548c7281d944e Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Mon, 3 Nov 2025 23:35:50 +0800 Subject: [PATCH 03/27] add grad_x, grad_w computation Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 80 ++++++++++++------- 1 file changed, 53 insertions(+), 27 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 8402ca0d3..f7da9279b 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -32,7 +32,9 @@ def fused_linear_cross_entropy_fwd( lse = torch.full((BT,), fill_value=-torch.inf, device=x.device, dtype=torch.float32) # DEBUG nll = torch.zeros(BT, device=x.device, dtype=torch.float32) neg_target_logits = torch.zeros(BT, device=x.device, dtype=torch.float32) - grad_x = torch.zeros(BT, H, device=x.device, dtype=torch.float32) + grad_x = torch.zeros_like(x, dtype=torch.float32) + grad_w = torch.zeros_like(weight, dtype=torch.float32) + grad_logits = torch.zeros_like(logits, dtype=torch.float32) for tile_bt in hl.tile(BT, block_size=block_size_bt): @@ -49,7 +51,6 @@ def fused_linear_cross_entropy_fwd( weight_tile = weight[tile_v, tile_h] acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32) - logits[tile_bt, tile_v] = acc # DEBUG # online softmax statistics @@ -82,27 +83,44 @@ def fused_linear_cross_entropy_fwd( logits_tile = logits[tile_bt, tile_v] - # softmax(x_i) = exp(x_i) / sum(exp(x_i)) + # softmax(x_i) = exp(x_i) / sum(exp(x_i)) # = exp(x_i) / log(exp(sum(x_i))) # = exp(x_i) / lse = exp(x_i - lse) - grad_logits_tile = torch.exp(logits_tile - lse_tile[:, None]) - - # grad_x = grad_logits @ weight + grad_logits_tile = torch.exp(logits_tile - lse_tile[:, None]) + offset = tile_v.index.unsqueeze(0) # [1, tile_v] + mask = target_indices == offset # [tile_bt, tile_v] + grad_logits_tile = grad_logits_tile - mask.float() + grad_logits[tile_bt, tile_v] = grad_logits_tile + for tile_h in hl.tile(H, block_size=block_size_h): - weight_tile = weight[tile_v, tile_h] - partial_grad_x = hl.dot(grad_logits_tile, weight_tile, out_dtype=torch.float32) + # grad_x = grad_logits @ weight + rhs_tile = weight[tile_v, tile_h] + partial_grad_x = hl.dot(grad_logits_tile, rhs_tile, out_dtype=torch.float32) hl.atomic_add(grad_x, [tile_bt, tile_h], partial_grad_x) + # grad_w = grad_logits.T[tile_v, tile_bt] @ x[tile_bt, tile_h] + rhs_tile = x[tile_bt, tile_h] + partial_grad_w = hl.dot(grad_logits_tile.T, rhs_tile, out_dtype=torch.float32) + hl.atomic_add(grad_w, [tile_v, tile_h], partial_grad_w) if reduction == "mean": - loss = nll.sum() / nll.numel() + loss = nll.mean() elif reduction == "sum": loss = nll.sum() else: loss = nll - return loss, logits.to(x.dtype), lse, neg_target_logits, grad_x - - + return dict( + { + "loss": loss, + "grad_x": grad_x, + "grad_w": grad_w, + "grad_logits": grad_logits, + "lse": lse, + "neg_target_logits": neg_target_logits, + "logits": logits, + "nll": nll, + } + ) # class LigerFusedLinearCrossEntropyHelionFunction(torch.autograd.Function): @@ -161,10 +179,12 @@ def __init__( super().__init__() self.lm_head = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) + self.logits = None def forward(self, x, target): - logits = self.lm_head(x).to(torch.float32) - return self.ce_loss(logits, target) + self.logits = self.lm_head(x).to(torch.float32) + self.logits.retain_grad() + return self.ce_loss(self.logits, target) class LigerLMHeadCE(torch.nn.Module): @@ -195,8 +215,8 @@ def forward(self, x, target): batch_size = 2 seq_len = 1024 - hidden_size = 4096 - vocab_size = 32000 + hidden_size = 1024 + vocab_size = 2048 dtype = torch.float32 reduction = "none" ignore_index = -100 @@ -216,9 +236,18 @@ def forward(self, x, target): liger_input = input.detach().clone().requires_grad_(True) # Forward pass - ref_loss = ref_lm_head_ce(ref_input, target) + ref_loss: torch.Tensor = ref_lm_head_ce(ref_input, target) ref_logits = input @ weight.T - liger_loss, liger_logits, liger_lse, liger_neg_target_logits, liger_grad_x = liger_lm_head_ce(liger_input, target) + liger_output = liger_lm_head_ce(liger_input, target) + + liger_loss = liger_output["loss"] + liger_grad_x = liger_output["grad_x"] + liger_grad_w = liger_output["grad_w"] + liger_lse = liger_output["lse"] + liger_neg_target_logits = liger_output["neg_target_logits"] + liger_logits = liger_output["logits"] + liger_grad_logits = liger_output["grad_logits"] + liger_logprobs = torch.nn.functional.log_softmax(liger_logits, dim=-1) ref_logprobs = torch.nn.functional.log_softmax(ref_logits, dim=-1) @@ -229,7 +258,6 @@ def forward(self, x, target): ref_logits, mask=target[:, None] == torch.arange(vocab_size, device=ref_logits.device)[None, :] ) - for i in range(5): print("=" * 30 + f"(i = {i})" + "=" * 30) print(f"{ref_lse[i]=}") @@ -242,15 +270,13 @@ def forward(self, x, target): torch.testing.assert_close(liger_logprobs, ref_logprobs, rtol=1e-1, atol=1e-1) torch.testing.assert_close(liger_lse, ref_lse, rtol=1e-1, atol=1e-1) torch.testing.assert_close(liger_neg_target_logits, ref_neg_target_logits, rtol=1e-1, atol=1e-1) - torch.testing.assert_close(ref_loss, liger_loss, rtol=1e-1, atol=1e-1) - + torch.testing.assert_close(liger_loss, ref_loss, rtol=1e-1, atol=1e-1) # Backward pass - ref_loss.backward() + ref_loss.backward(torch.ones_like(ref_loss), retain_graph=True) + assert liger_grad_logits is not None + assert ref_lm_head_ce.logits.grad is not None + torch.testing.assert_close(liger_grad_logits, ref_lm_head_ce.logits.grad, rtol=1e-1, atol=1e-1) torch.testing.assert_close(liger_grad_x, ref_input.grad, rtol=1e-1, atol=1e-1) + torch.testing.assert_close(liger_grad_w, ref_lm_head_ce.lm_head.weight.grad, rtol=1e-1, atol=1e-1) - # ref_loss.backward() - # liger_loss.backward() - - # torch.testing.assert_close(ref_input.grad, liger_input.grad) - # torch.testing.assert_close(ref_lm_head_ce.lm_head.weight.grad, liger_lm_head_ce.lm_head.weight.grad) From 8e9c13e6e38f4adbbc1243b87ddeb3a1f7d96796 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Tue, 4 Nov 2025 00:46:47 +0800 Subject: [PATCH 04/27] clean up Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 171 +++++++----------- 1 file changed, 67 insertions(+), 104 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index f7da9279b..96b12f7cd 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -3,8 +3,9 @@ import torch -@helion.kernel(autotune_effort="none") -def fused_linear_cross_entropy_fwd( +# TODO: autotune and find the best configs for different devices +@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) +def fused_linear_cross_entropy_fwd_bwd( x: torch.Tensor, weight: torch.Tensor, target: torch.Tensor, @@ -20,7 +21,7 @@ def fused_linear_cross_entropy_fwd( ignore_index: index to ignore in the target reduction: reduction to apply to the loss Returns: - loss: loss tensor of shape [1] + loss: loss tensor of shape [1] if reduction is "mean" or "sum", [BT] otherwise """ BT, H = x.size() V = weight.size(0) @@ -28,14 +29,15 @@ def fused_linear_cross_entropy_fwd( block_size_h = hl.register_block_size(H) block_size_v = hl.register_block_size(V) - logits = torch.empty(BT, V, device=x.device, dtype=torch.float32) - lse = torch.full((BT,), fill_value=-torch.inf, device=x.device, dtype=torch.float32) # DEBUG - nll = torch.zeros(BT, device=x.device, dtype=torch.float32) - neg_target_logits = torch.zeros(BT, device=x.device, dtype=torch.float32) + nll = torch.zeros(BT, device=x.device, dtype=torch.float32) grad_x = torch.zeros_like(x, dtype=torch.float32) grad_w = torch.zeros_like(weight, dtype=torch.float32) - grad_logits = torch.zeros_like(logits, dtype=torch.float32) + # May be useful for splitting fwd and bwd + # lse = torch.full((BT,), fill_value=-torch.inf, device=x.device, dtype=torch.float32) + # neg_target_logits = torch.zeros(BT, device=x.device, dtype=torch.float32) + + n_non_ignore = (target != ignore_index).sum().unsqueeze(0) for tile_bt in hl.tile(BT, block_size=block_size_bt): m_i = hl.zeros([tile_bt], dtype=torch.float32) - float("inf") @@ -51,7 +53,6 @@ def fused_linear_cross_entropy_fwd( weight_tile = weight[tile_v, tile_h] acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32) - logits[tile_bt, tile_v] = acc # DEBUG # online softmax statistics m_ij = torch.maximum(m_i, torch.amax(acc, dim=-1)) @@ -65,32 +66,28 @@ def fused_linear_cross_entropy_fwd( # loss computation: -logsoftmax(x_y) = -log(exp(x_y) / sum(exp(x_i))) = -x_y + log(sum(exp(x_i))) lse_tile = m_i + torch.log(d_i) - lse[tile_bt] = lse_tile - - neg_target_logits[tile_bt] = nll_tile - nll_tile = nll_tile + lse_tile nll[tile_bt] = nll_tile # gradients computation for tile_v in hl.tile(V, block_size=block_size_v): # Restore logits - # acc = hl.zeros([tile_bt, tile_v], dtype=torch.float32) - # for tile_h in hl.tile(H, block_size=block_size_h): - # x_tile = x[tile_bt, tile_h] - # weight_tile = weight[tile_v, tile_h] - # acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32) + acc = hl.zeros([tile_bt, tile_v], dtype=torch.float32) + for tile_h in hl.tile(H, block_size=block_size_h): + x_tile = x[tile_bt, tile_h] + weight_tile = weight[tile_v, tile_h] + acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32) - logits_tile = logits[tile_bt, tile_v] # softmax(x_i) = exp(x_i) / sum(exp(x_i)) # = exp(x_i) / log(exp(sum(x_i))) # = exp(x_i) / lse = exp(x_i - lse) - grad_logits_tile = torch.exp(logits_tile - lse_tile[:, None]) + grad_logits_tile = torch.exp(acc - lse_tile[:, None]) offset = tile_v.index.unsqueeze(0) # [1, tile_v] mask = target_indices == offset # [tile_bt, tile_v] grad_logits_tile = grad_logits_tile - mask.float() - grad_logits[tile_bt, tile_v] = grad_logits_tile + n_non_ignore_value = hl.load(n_non_ignore, [0]) + grad_logits_tile /= n_non_ignore_value for tile_h in hl.tile(H, block_size=block_size_h): # grad_x = grad_logits @ weight @@ -103,51 +100,46 @@ def fused_linear_cross_entropy_fwd( hl.atomic_add(grad_w, [tile_v, tile_h], partial_grad_w) if reduction == "mean": - loss = nll.mean() + loss = nll.sum() / n_non_ignore.squeeze() elif reduction == "sum": loss = nll.sum() else: loss = nll - return dict( + # return format is not determined yet + return loss, dict( { - "loss": loss, - "grad_x": grad_x, - "grad_w": grad_w, - "grad_logits": grad_logits, - "lse": lse, - "neg_target_logits": neg_target_logits, - "logits": logits, - "nll": nll, + "grad_x": grad_x, + "grad_w": grad_w, } ) -# class LigerFusedLinearCrossEntropyHelionFunction(torch.autograd.Function): -# @staticmethod -# def forward( -# ctx, -# _input, -# weight, -# target, -# ignore_index=-100, -# reduction="mean", -# ): -# assert _input.ndim == weight.ndim -# loss, grad_input, grad_weight = fused_linear_cross_entropy_fwd_bwd( -# _input, -# weight, -# target, -# ignore_index, -# reduction, -# ) -# ctx.save_for_backward(grad_input, grad_weight) -# return loss - -# @staticmethod -# def backward(ctx, grad_output): -# grad_input, grad_weight = ctx.saved_tensors -# return grad_input * grad_output, grad_weight * grad_output, None, None, None +class LigerFusedLinearCrossEntropyHelionFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + _input, + weight, + target, + ignore_index=-100, + reduction="mean", + ): + loss, aux_output = fused_linear_cross_entropy_fwd_bwd( + _input, + weight, + target, + ignore_index, + reduction, + ) + ctx.save_for_backward(aux_output["grad_x"], aux_output["grad_w"]) + return loss + + @staticmethod + def backward(ctx, grad_output): + assert grad_output.ndim == 0, "token_scaling is not supported. grad_output must be a scalar" + grad_input, grad_weight = ctx.saved_tensors + return grad_input * grad_output, grad_weight * grad_output, None, None, None class LigerFusedLinearCrossEntropyHelion(torch.nn.Module): @@ -157,15 +149,13 @@ def __init__(self, ignore_index=-100, reduction="mean"): self.reduction = reduction def forward(self, _input, weight, target): - # return LigerFusedLinearCrossEntropyHelionFunction.apply( - # _input, - # weight, - # target, - # self.ignore_index, - # self.reduction - # ) - return fused_linear_cross_entropy_fwd(_input, weight, target, self.ignore_index, self.reduction) - + return LigerFusedLinearCrossEntropyHelionFunction.apply( + _input, + weight, + target, + self.ignore_index, + self.reduction + ) class TorchLMHeadCE(torch.nn.Module): def __init__( @@ -183,7 +173,6 @@ def __init__( def forward(self, x, target): self.logits = self.lm_head(x).to(torch.float32) - self.logits.retain_grad() return self.ce_loss(self.logits, target) @@ -218,7 +207,7 @@ def forward(self, x, target): hidden_size = 1024 vocab_size = 2048 dtype = torch.float32 - reduction = "none" + reduction = "mean" ignore_index = -100 input = torch.randn(batch_size * seq_len, hidden_size, device=device, requires_grad=True) @@ -237,46 +226,20 @@ def forward(self, x, target): # Forward pass ref_loss: torch.Tensor = ref_lm_head_ce(ref_input, target) - ref_logits = input @ weight.T - liger_output = liger_lm_head_ce(liger_input, target) - - liger_loss = liger_output["loss"] - liger_grad_x = liger_output["grad_x"] - liger_grad_w = liger_output["grad_w"] - liger_lse = liger_output["lse"] - liger_neg_target_logits = liger_output["neg_target_logits"] - liger_logits = liger_output["logits"] - liger_grad_logits = liger_output["grad_logits"] + liger_loss: torch.Tensor = liger_lm_head_ce(liger_input, target) + torch.testing.assert_close(liger_loss, ref_loss, rtol=1e-1, atol=1e-1) + + # Backward pass (backward() with reduction=="none" is not supported yet) + if reduction == "none": + pass + else: + liger_loss.backward() + ref_loss.backward() - liger_logprobs = torch.nn.functional.log_softmax(liger_logits, dim=-1) - ref_logprobs = torch.nn.functional.log_softmax(ref_logits, dim=-1) + torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=1e-1, atol=1e-1) + torch.testing.assert_close(liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=1e-1, atol=1e-1) - ref_lse = torch.logsumexp(ref_logits, dim=-1) - ref_neg_target_logits = torch.nn.functional.nll_loss(ref_logits, target, reduction="none") - ref_neg_target_logits2 = torch.masked_select( - ref_logits, mask=target[:, None] == torch.arange(vocab_size, device=ref_logits.device)[None, :] - ) - for i in range(5): - print("=" * 30 + f"(i = {i})" + "=" * 30) - print(f"{ref_lse[i]=}") - print(f"{ref_neg_target_logits[i]=}") - print(f"{ref_neg_target_logits[i] + ref_lse[i]=}") - print(f"{ref_loss[i]=}") - print(f"{liger_loss[i]=}") - print("=" * 64) - - torch.testing.assert_close(liger_logprobs, ref_logprobs, rtol=1e-1, atol=1e-1) - torch.testing.assert_close(liger_lse, ref_lse, rtol=1e-1, atol=1e-1) - torch.testing.assert_close(liger_neg_target_logits, ref_neg_target_logits, rtol=1e-1, atol=1e-1) - torch.testing.assert_close(liger_loss, ref_loss, rtol=1e-1, atol=1e-1) - # Backward pass - ref_loss.backward(torch.ones_like(ref_loss), retain_graph=True) - assert liger_grad_logits is not None - assert ref_lm_head_ce.logits.grad is not None - torch.testing.assert_close(liger_grad_logits, ref_lm_head_ce.logits.grad, rtol=1e-1, atol=1e-1) - torch.testing.assert_close(liger_grad_x, ref_input.grad, rtol=1e-1, atol=1e-1) - torch.testing.assert_close(liger_grad_w, ref_lm_head_ce.lm_head.weight.grad, rtol=1e-1, atol=1e-1) From 368433b722e53e3928ac006bc2bd59bf7e911669 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Tue, 4 Nov 2025 00:47:09 +0800 Subject: [PATCH 05/27] format Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 96b12f7cd..8d0da00cf 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -29,13 +29,13 @@ def fused_linear_cross_entropy_fwd_bwd( block_size_h = hl.register_block_size(H) block_size_v = hl.register_block_size(V) - nll = torch.zeros(BT, device=x.device, dtype=torch.float32) + nll = torch.zeros(BT, device=x.device, dtype=torch.float32) grad_x = torch.zeros_like(x, dtype=torch.float32) grad_w = torch.zeros_like(weight, dtype=torch.float32) # May be useful for splitting fwd and bwd - # lse = torch.full((BT,), fill_value=-torch.inf, device=x.device, dtype=torch.float32) - # neg_target_logits = torch.zeros(BT, device=x.device, dtype=torch.float32) + # lse = torch.full((BT,), fill_value=-torch.inf, device=x.device, dtype=torch.float32) + # neg_target_logits = torch.zeros(BT, device=x.device, dtype=torch.float32) n_non_ignore = (target != ignore_index).sum().unsqueeze(0) @@ -53,7 +53,6 @@ def fused_linear_cross_entropy_fwd_bwd( weight_tile = weight[tile_v, tile_h] acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32) - # online softmax statistics m_ij = torch.maximum(m_i, torch.amax(acc, dim=-1)) d_i = d_i * torch.exp(m_i - m_ij) + torch.exp(acc - m_ij[:, None]).sum(dim=-1) @@ -78,7 +77,6 @@ def fused_linear_cross_entropy_fwd_bwd( weight_tile = weight[tile_v, tile_h] acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32) - # softmax(x_i) = exp(x_i) / sum(exp(x_i)) # = exp(x_i) / log(exp(sum(x_i))) # = exp(x_i) / lse = exp(x_i - lse) @@ -107,10 +105,10 @@ def fused_linear_cross_entropy_fwd_bwd( loss = nll # return format is not determined yet - return loss, dict( + return loss, dict( { - "grad_x": grad_x, - "grad_w": grad_w, + "grad_x": grad_x, + "grad_w": grad_w, } ) @@ -150,13 +148,10 @@ def __init__(self, ignore_index=-100, reduction="mean"): def forward(self, _input, weight, target): return LigerFusedLinearCrossEntropyHelionFunction.apply( - _input, - weight, - target, - self.ignore_index, - self.reduction + _input, weight, target, self.ignore_index, self.reduction ) + class TorchLMHeadCE(torch.nn.Module): def __init__( self, @@ -229,7 +224,7 @@ def forward(self, x, target): liger_loss: torch.Tensor = liger_lm_head_ce(liger_input, target) torch.testing.assert_close(liger_loss, ref_loss, rtol=1e-1, atol=1e-1) - + # Backward pass (backward() with reduction=="none" is not supported yet) if reduction == "none": pass @@ -238,8 +233,6 @@ def forward(self, x, target): ref_loss.backward() torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=1e-1, atol=1e-1) - torch.testing.assert_close(liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=1e-1, atol=1e-1) - - - - + torch.testing.assert_close( + liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=1e-1, atol=1e-1 + ) From 164e63ada8c809505e4aa24fecad47a51dd7217a Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Tue, 4 Nov 2025 11:46:46 +0800 Subject: [PATCH 06/27] Fix incorrect grad_w computation with reduction="mean" Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- src/liger_kernel/ops/helion/fused_linear_cross_entropy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 8d0da00cf..e50c28fe8 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -85,7 +85,8 @@ def fused_linear_cross_entropy_fwd_bwd( mask = target_indices == offset # [tile_bt, tile_v] grad_logits_tile = grad_logits_tile - mask.float() n_non_ignore_value = hl.load(n_non_ignore, [0]) - grad_logits_tile /= n_non_ignore_value + if reduction == "mean": + grad_logits_tile /= n_non_ignore_value for tile_h in hl.tile(H, block_size=block_size_h): # grad_x = grad_logits @ weight @@ -202,7 +203,7 @@ def forward(self, x, target): hidden_size = 1024 vocab_size = 2048 dtype = torch.float32 - reduction = "mean" + reduction = "sum" ignore_index = -100 input = torch.randn(batch_size * seq_len, hidden_size, device=device, requires_grad=True) From b50870cd33a22b692b74e055415c675b70750c10 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Tue, 4 Nov 2025 11:47:10 +0800 Subject: [PATCH 07/27] Add unit test Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../helion/test_fused_linear_cross_entropy.py | 140 ++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 test/transformers/helion/test_fused_linear_cross_entropy.py diff --git a/test/transformers/helion/test_fused_linear_cross_entropy.py b/test/transformers/helion/test_fused_linear_cross_entropy.py new file mode 100644 index 000000000..95c14fa11 --- /dev/null +++ b/test/transformers/helion/test_fused_linear_cross_entropy.py @@ -0,0 +1,140 @@ +import os +import random + +import numpy as np +import pytest +import torch + +from liger_kernel.ops.helion.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyHelion +from liger_kernel.utils import infer_device + +device = infer_device() + +def supports_bfloat16(): + if device == "cuda": + return torch.cuda.get_device_capability() >= (8, 0) # Ampere and newer + elif device == "xpu": + return True + else: + return False + +def set_seed(seed=42): + """ + Fix all random seeds we use for reproducibility. + """ + # Python random seed + random.seed(seed) + # Numpy random seed + np.random.seed(0) + # PyTorch random seed + torch.manual_seed(seed) + + if device == "cuda": + # If you are using CUDA + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. + + # PyTorch backend settings + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + elif device == "xpu": + # If you are using XPU + torch.xpu.manual_seed(seed) + torch.xpu.manual_seed_all(seed) + + # Python hash seed + os.environ["PYTHONHASHSEED"] = str(seed) + +set_seed(42) + +class TorchLMHeadCE(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + ignore_index: int = -100, + reduction: str = "mean", + ): + super().__init__() + self.lm_head = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) + self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) + + def forward(self, x, target): + logits = self.lm_head(x).to(torch.float32) + return self.ce_loss(logits, target) + + +class LigerLMHeadCE(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + ignore_index: int = -100, + reduction: str = "mean", + ): + super().__init__() + self.lm_head = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) + self.flce = LigerFusedLinearCrossEntropyHelion(ignore_index=ignore_index, reduction=reduction) + + def forward(self, x, target): + return self.flce(x, self.lm_head.weight, target) + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 1024, 4096, 32000), # llama + (3, 423, 1000, 10000), # weird shapes + ], +) +@pytest.mark.parametrize("reduction", ["sum", "mean", "none"]) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + pytest.param( + torch.bfloat16, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + (torch.float32, 1e-3, 1e-2), + ], +) +def test_fused_linear_cross_entropy_correctness(B, T, H, V, reduction, dtype, atol, rtol): + input = torch.randn(B * T, H, device=device, requires_grad=True) + weight = torch.randn(V, H, device=device, requires_grad=True) + target = torch.randint(0, V, (B * T,), device=device) + + ref_lm_head_ce = TorchLMHeadCE(H, V, dtype=dtype, reduction=reduction).to(device=device) + liger_lm_head_ce = LigerLMHeadCE(H, V, dtype=dtype, reduction=reduction).to(device=device) + + ref_lm_head_ce.lm_head.weight.data = weight.data + liger_lm_head_ce.lm_head.weight.data = weight.data + + ref_input = input.detach().clone().requires_grad_(True) + liger_input = input.detach().clone().requires_grad_(True) + + # Forward pass + ref_loss: torch.Tensor = ref_lm_head_ce(ref_input, target) + liger_loss: torch.Tensor = liger_lm_head_ce(liger_input, target) + + torch.testing.assert_close(liger_loss, ref_loss, rtol=rtol, atol=atol) + + # Backward pass (backward() with reduction=="none" is not supported yet) + if reduction == "none": + pass + else: + liger_loss.backward() + ref_loss.backward() + + assert liger_lm_head_ce.lm_head.weight.grad.isnan().sum() == 0, "lm_head.weight of liger contains nan" + assert ref_lm_head_ce.lm_head.weight.grad.isnan().sum() == 0, "lm_head.weight of ref contains nan" + assert liger_input.grad.isnan().sum() == 0 + assert liger_input.grad.isinf().sum() == 0 + torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=rtol, atol=atol) + torch.testing.assert_close( + liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=rtol, atol=atol, + ) + + From f890711bd567c3d68e5ba5f477acd3d9c22bff57 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Tue, 4 Nov 2025 12:55:22 +0800 Subject: [PATCH 08/27] Improve n_non_ignore read efficiency and ERROR comments Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index e50c28fe8..be0e146e0 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -43,7 +43,10 @@ def fused_linear_cross_entropy_fwd_bwd( m_i = hl.zeros([tile_bt], dtype=torch.float32) - float("inf") d_i = hl.zeros([tile_bt], dtype=torch.float32) nll_tile = hl.zeros([tile_bt], dtype=torch.float32) - # target_indices = target[tile_bt][:, None] # [tile_bt, 1] # ERROR + if reduction == "mean": + n_non_ignore_value = hl.load(n_non_ignore, [0]) + + # target_indices = target[tile_bt][:, None] # ERROR: it introduces a new size, which is not broadcastable target_indices = target[tile_bt].unsqueeze(1) # [tile_bt, 1] for tile_v in hl.tile(V, block_size=block_size_v): # logits computation @@ -58,7 +61,7 @@ def fused_linear_cross_entropy_fwd_bwd( d_i = d_i * torch.exp(m_i - m_ij) + torch.exp(acc - m_ij[:, None]).sum(dim=-1) m_i = m_ij - # offset = tile_v.index[None, :] # [1, tile_v] # ERROR + # offset = tile_v.index[None, :] # ERROR: it introduces a new size, which is not broadcastable offset = tile_v.index.unsqueeze(0) # [1, tile_v] mask = target_indices == offset # [tile_bt, tile_v] nll_tile += torch.sum(-acc * mask, dim=-1) # [tile_bt] @@ -66,8 +69,11 @@ def fused_linear_cross_entropy_fwd_bwd( # loss computation: -logsoftmax(x_y) = -log(exp(x_y) / sum(exp(x_i))) = -x_y + log(sum(exp(x_i))) lse_tile = m_i + torch.log(d_i) nll_tile = nll_tile + lse_tile - nll[tile_bt] = nll_tile + if reduction == "mean": + nll_tile /= n_non_ignore_value + + nll[tile_bt] = nll_tile # gradients computation for tile_v in hl.tile(V, block_size=block_size_v): # Restore logits @@ -84,7 +90,9 @@ def fused_linear_cross_entropy_fwd_bwd( offset = tile_v.index.unsqueeze(0) # [1, tile_v] mask = target_indices == offset # [tile_bt, tile_v] grad_logits_tile = grad_logits_tile - mask.float() - n_non_ignore_value = hl.load(n_non_ignore, [0]) + # handle out of bound values in grad_logits_tile + grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :]) + if reduction == "mean": grad_logits_tile /= n_non_ignore_value @@ -98,10 +106,8 @@ def fused_linear_cross_entropy_fwd_bwd( partial_grad_w = hl.dot(grad_logits_tile.T, rhs_tile, out_dtype=torch.float32) hl.atomic_add(grad_w, [tile_v, tile_h], partial_grad_w) - if reduction == "mean": - loss = nll.sum() / n_non_ignore.squeeze() - elif reduction == "sum": - loss = nll.sum() + if reduction != "none": + loss = nll.sum() else: loss = nll From 04f07fa5f57396ce8d22702cdfedd48fa0607c0f Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Tue, 4 Nov 2025 12:55:35 +0800 Subject: [PATCH 09/27] Set higher tolerance Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../helion/test_fused_linear_cross_entropy.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/transformers/helion/test_fused_linear_cross_entropy.py b/test/transformers/helion/test_fused_linear_cross_entropy.py index 95c14fa11..a0a935e0d 100644 --- a/test/transformers/helion/test_fused_linear_cross_entropy.py +++ b/test/transformers/helion/test_fused_linear_cross_entropy.py @@ -1,5 +1,6 @@ import os import random +import warnings import numpy as np import pytest @@ -94,11 +95,11 @@ def forward(self, x, target): [ pytest.param( torch.bfloat16, - 1e-2, + 1e-1, 1e-2, marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), - (torch.float32, 1e-3, 1e-2), + (torch.float32, 1e-1, 1e-2), ], ) def test_fused_linear_cross_entropy_correctness(B, T, H, V, reduction, dtype, atol, rtol): @@ -123,7 +124,8 @@ def test_fused_linear_cross_entropy_correctness(B, T, H, V, reduction, dtype, at # Backward pass (backward() with reduction=="none" is not supported yet) if reduction == "none": - pass + warnings.warn("backward() with reduction='none' is not supported yet", UserWarning) + else: liger_loss.backward() ref_loss.backward() From 8781e8ed9842ac8a70b6e6a5a7fb49825d5da2c6 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Tue, 4 Nov 2025 08:25:48 +0000 Subject: [PATCH 10/27] Add benchmark Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 31 ++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index be0e146e0..3f362d89d 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -2,9 +2,10 @@ import helion.language as hl import torch +# Best config for default llama3Config(hidden_size=4096, vocab_size=32000) with 4096 bs*seqlen input +config = helion.Config(block_sizes=[32, 32, 256], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor'], load_eviction_policies=['first', '', 'first', 'last', '', '', 'last', 'first'], num_stages=5, num_warps=4, pid_type='flat', range_flattens=[None, True, False], range_multi_buffers=[None, True, False], range_num_stages=[0, 0, 0], range_unroll_factors=[0, 1, 1], range_warp_specializes=[]) -# TODO: autotune and find the best configs for different devices -@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) +@helion.kernel(config=config, ignore_warnings=[helion.exc.TensorOperationInWrapper]) def fused_linear_cross_entropy_fwd_bwd( x: torch.Tensor, weight: torch.Tensor, @@ -204,12 +205,12 @@ def forward(self, x, target): device = "cuda" - batch_size = 2 - seq_len = 1024 - hidden_size = 1024 - vocab_size = 2048 + batch_size = 8 + seq_len = 4096 + hidden_size = 4096 + vocab_size = 32000 dtype = torch.float32 - reduction = "sum" + reduction = "mean" ignore_index = -100 input = torch.randn(batch_size * seq_len, hidden_size, device=device, requires_grad=True) @@ -243,3 +244,19 @@ def forward(self, x, target): torch.testing.assert_close( liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=1e-1, atol=1e-1 ) + + + # Benchmark + from helion._testing import run_example + from functools import partial + + def fwd_bwd_fn(input, target, fn): + loss = fn(input, target) + loss.backward() + return loss + liger_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=liger_lm_head_ce) + ref_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=ref_lm_head_ce) + + + run_example(liger_lm_head_ce, ref_lm_head_ce, (input, target), kernel_name="helion_flce_fwd", baseline_name="torch_fwd", rtol=1e-1, atol=1e-1) + run_example(liger_lm_head_ce_fwd_bwd, ref_lm_head_ce_fwd_bwd, (input, target), kernel_name="helion_flce_fwd_bwd", baseline_name="torch_fwd_bwd", rtol=1e-1, atol=1e-1) From 98de5b9c4b8e0d700f9e068db9e6b96b24d24bf3 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu, 6 Nov 2025 15:56:51 +0800 Subject: [PATCH 11/27] Unfuse forward/backward and use lock Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 162 +++++++++++++----- 1 file changed, 116 insertions(+), 46 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 3f362d89d..29aafdbc1 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -1,3 +1,5 @@ +import math + import helion import helion.language as hl import torch @@ -5,8 +7,29 @@ # Best config for default llama3Config(hidden_size=4096, vocab_size=32000) with 4096 bs*seqlen input config = helion.Config(block_sizes=[32, 32, 256], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor'], load_eviction_policies=['first', '', 'first', 'last', '', '', 'last', 'first'], num_stages=5, num_warps=4, pid_type='flat', range_flattens=[None, True, False], range_multi_buffers=[None, True, False], range_num_stages=[0, 0, 0], range_unroll_factors=[0, 1, 1], range_warp_specializes=[]) -@helion.kernel(config=config, ignore_warnings=[helion.exc.TensorOperationInWrapper]) -def fused_linear_cross_entropy_fwd_bwd( +def helion_lock_acquire(lock_ptr, lock_index): + hl.inline_triton( + """ + while tl.atomic_cas({0} + {1}, 0, 1, sem="acquire") == 1: + pass + """, + args=(lock_ptr, lock_index), + output_like=None, + ) + +def helion_lock_release(lock_ptr, lock_index): + hl.inline_triton( + """ + tl.atomic_xchg({0} + {1}, 0, sem="release") + """, + args=(lock_ptr, lock_index), + output_like=None, + ) + + +# @helion.kernel(config=config, ignore_warnings=[helion.exc.TensorOperationInWrapper]) +@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) +def fused_linear_cross_entropy_fwd( x: torch.Tensor, weight: torch.Tensor, target: torch.Tensor, @@ -31,8 +54,7 @@ def fused_linear_cross_entropy_fwd_bwd( block_size_v = hl.register_block_size(V) nll = torch.zeros(BT, device=x.device, dtype=torch.float32) - grad_x = torch.zeros_like(x, dtype=torch.float32) - grad_w = torch.zeros_like(weight, dtype=torch.float32) + lse = torch.zeros(BT, device=x.device, dtype=torch.float32) # May be useful for splitting fwd and bwd # lse = torch.full((BT,), fill_value=-torch.inf, device=x.device, dtype=torch.float32) @@ -40,6 +62,7 @@ def fused_linear_cross_entropy_fwd_bwd( n_non_ignore = (target != ignore_index).sum().unsqueeze(0) + # forward for tile_bt in hl.tile(BT, block_size=block_size_bt): m_i = hl.zeros([tile_bt], dtype=torch.float32) - float("inf") d_i = hl.zeros([tile_bt], dtype=torch.float32) @@ -75,50 +98,86 @@ def fused_linear_cross_entropy_fwd_bwd( nll_tile /= n_non_ignore_value nll[tile_bt] = nll_tile - # gradients computation - for tile_v in hl.tile(V, block_size=block_size_v): - # Restore logits - acc = hl.zeros([tile_bt, tile_v], dtype=torch.float32) - for tile_h in hl.tile(H, block_size=block_size_h): - x_tile = x[tile_bt, tile_h] - weight_tile = weight[tile_v, tile_h] - acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32) - - # softmax(x_i) = exp(x_i) / sum(exp(x_i)) - # = exp(x_i) / log(exp(sum(x_i))) - # = exp(x_i) / lse = exp(x_i - lse) - grad_logits_tile = torch.exp(acc - lse_tile[:, None]) - offset = tile_v.index.unsqueeze(0) # [1, tile_v] - mask = target_indices == offset # [tile_bt, tile_v] - grad_logits_tile = grad_logits_tile - mask.float() - # handle out of bound values in grad_logits_tile - grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :]) - - if reduction == "mean": - grad_logits_tile /= n_non_ignore_value - - for tile_h in hl.tile(H, block_size=block_size_h): - # grad_x = grad_logits @ weight - rhs_tile = weight[tile_v, tile_h] - partial_grad_x = hl.dot(grad_logits_tile, rhs_tile, out_dtype=torch.float32) - hl.atomic_add(grad_x, [tile_bt, tile_h], partial_grad_x) - # grad_w = grad_logits.T[tile_v, tile_bt] @ x[tile_bt, tile_h] - rhs_tile = x[tile_bt, tile_h] - partial_grad_w = hl.dot(grad_logits_tile.T, rhs_tile, out_dtype=torch.float32) - hl.atomic_add(grad_w, [tile_v, tile_h], partial_grad_w) - + lse[tile_bt] = lse_tile + if reduction != "none": loss = nll.sum() else: loss = nll + + return loss, lse + +@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) +def fused_linear_cross_entropy_bwd( + x: torch.Tensor, + weight: torch.Tensor, + target: torch.Tensor, + lse: torch.Tensor, + ignore_index: int = -100, + reduction: str = "mean", +): + BT, H = x.size() + V = weight.size(0) + block_size_bt = hl.register_block_size(BT) + block_size_h = hl.register_block_size(H) + block_size_v = hl.register_block_size(V) + grad_x = torch.zeros_like(x, dtype=torch.float32) + grad_w = torch.zeros_like(weight, dtype=torch.float32) + n_non_ignore = (target != ignore_index).sum().unsqueeze(0) + + num_block_bt = (BT + block_size_bt - 1)//block_size_bt + num_block_h = (H + block_size_h - 1)//block_size_h + num_block_v = (V + block_size_v - 1)//block_size_v + grad_x_lock = torch.zeros((num_block_bt, num_block_h), dtype=torch.int32, device=x.device) + grad_w_lock = torch.zeros((num_block_v, num_block_h), dtype=torch.int32, device=x.device) + # backward + for tile_bt, tile_v in hl.tile([BT, V], block_size=(block_size_bt, block_size_v)): + # Restore logits + acc2 = hl.zeros([tile_bt, tile_v], dtype=torch.float32) + for tile_h in hl.tile(H, block_size=block_size_h): + x_tile = x[tile_bt, tile_h] + weight_tile = weight[tile_v, tile_h] + acc2 = hl.dot(x_tile, weight_tile.T, acc=acc2, out_dtype=torch.float32) + + # softmax(x_i) = exp(x_i) / sum(exp(x_i)) + # = exp(x_i) / log(exp(sum(x_i))) + # = exp(x_i) / lse = exp(x_i - lse) + lse_tile = lse[tile_bt] + target_indices = target[tile_bt].unsqueeze(1) # [tile_bt, 1] + if reduction == "mean": + n_non_ignore_value = hl.load(n_non_ignore, [0]) + + grad_logits_tile = torch.exp(acc2 - lse_tile[:, None]) + offset = tile_v.index.unsqueeze(0) # [1, tile_v] + mask = target_indices == offset # [tile_bt, tile_v] + grad_logits_tile = grad_logits_tile - mask.float() + # handle out of bound values in grad_logits_tile + grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :]) + + if reduction == "mean": + grad_logits_tile /= n_non_ignore_value + + for tile_h in hl.tile(H, block_size=block_size_h): + # grad_x = grad_logits @ weight + rhs_tile = weight[tile_v, tile_h] + partial_grad_x = hl.dot(grad_logits_tile, rhs_tile, out_dtype=torch.float32) + helion_lock_acquire(grad_x_lock, tile_bt.id * num_block_h + tile_h.id) + grad_x[tile_bt, tile_h] += partial_grad_x + helion_lock_release(grad_x_lock, tile_bt.id * num_block_h + tile_h.id) + # hl.atomic_add(grad_x, [tile_bt, tile_h], partial_grad_x) + + # for tile_h in hl.tile(H, block_size=block_size_h): + # grad_w = grad_logits.T[tile_v, tile_bt] @ x[tile_bt, tile_h] + rhs_tile = x[tile_bt, tile_h] + partial_grad_w = hl.dot(grad_logits_tile.T, rhs_tile, out_dtype=torch.float32) + helion_lock_acquire(grad_w_lock, tile_v.id * num_block_h + tile_h.id) + grad_w[tile_v, tile_h] += partial_grad_w + helion_lock_release(grad_w_lock, tile_v.id * num_block_h + tile_h.id) + # hl.atomic_add(grad_w, [tile_v, tile_h], partial_grad_w) + - # return format is not determined yet - return loss, dict( - { - "grad_x": grad_x, - "grad_w": grad_w, - } - ) + + return grad_x, grad_w class LigerFusedLinearCrossEntropyHelionFunction(torch.autograd.Function): @@ -131,20 +190,30 @@ def forward( ignore_index=-100, reduction="mean", ): - loss, aux_output = fused_linear_cross_entropy_fwd_bwd( + loss, lse = fused_linear_cross_entropy_fwd( _input, weight, target, ignore_index, reduction, ) - ctx.save_for_backward(aux_output["grad_x"], aux_output["grad_w"]) + ctx.ignore_index = ignore_index + ctx.reduction = reduction + ctx.save_for_backward(_input, lse) return loss @staticmethod def backward(ctx, grad_output): assert grad_output.ndim == 0, "token_scaling is not supported. grad_output must be a scalar" - grad_input, grad_weight = ctx.saved_tensors + _input, lse = ctx.saved_tensors + grad_input, grad_weight = fused_linear_cross_entropy_bwd( + _input, + weight, + target, + lse, + ctx.ignore_index, + ctx.reduction, + ) return grad_input * grad_output, grad_weight * grad_output, None, None, None @@ -250,6 +319,7 @@ def forward(self, x, target): from helion._testing import run_example from functools import partial + def fwd_bwd_fn(input, target, fn): loss = fn(input, target) loss.backward() From 30b12ebc74de1821d16194d4f2bebb1a06f99acd Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu, 6 Nov 2025 15:57:19 +0800 Subject: [PATCH 12/27] testing misc Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 29aafdbc1..21d05146f 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -274,13 +274,19 @@ def forward(self, x, target): device = "cuda" - batch_size = 8 - seq_len = 4096 + batch_size = 2 + seq_len = 1024 hidden_size = 4096 vocab_size = 32000 + # batch_size = 2 + # seq_len = 256 + # hidden_size = 512 + # vocab_size = 1024 dtype = torch.float32 reduction = "mean" ignore_index = -100 + rtol = 1e-2 + atol = 1e-2 input = torch.randn(batch_size * seq_len, hidden_size, device=device, requires_grad=True) weight = torch.randn(vocab_size, hidden_size, device=device, requires_grad=True) @@ -300,7 +306,7 @@ def forward(self, x, target): ref_loss: torch.Tensor = ref_lm_head_ce(ref_input, target) liger_loss: torch.Tensor = liger_lm_head_ce(liger_input, target) - torch.testing.assert_close(liger_loss, ref_loss, rtol=1e-1, atol=1e-1) + torch.testing.assert_close(liger_loss, ref_loss, rtol=rtol, atol=atol) # Backward pass (backward() with reduction=="none" is not supported yet) if reduction == "none": @@ -309,16 +315,16 @@ def forward(self, x, target): liger_loss.backward() ref_loss.backward() - torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=1e-1, atol=1e-1) + torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=rtol, atol=atol) torch.testing.assert_close( - liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=1e-1, atol=1e-1 + liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=rtol, atol=atol ) # Benchmark - from helion._testing import run_example from functools import partial + from helion._testing import run_example def fwd_bwd_fn(input, target, fn): loss = fn(input, target) @@ -328,5 +334,5 @@ def fwd_bwd_fn(input, target, fn): ref_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=ref_lm_head_ce) - run_example(liger_lm_head_ce, ref_lm_head_ce, (input, target), kernel_name="helion_flce_fwd", baseline_name="torch_fwd", rtol=1e-1, atol=1e-1) - run_example(liger_lm_head_ce_fwd_bwd, ref_lm_head_ce_fwd_bwd, (input, target), kernel_name="helion_flce_fwd_bwd", baseline_name="torch_fwd_bwd", rtol=1e-1, atol=1e-1) + run_example(liger_lm_head_ce, ref_lm_head_ce, (input, target), kernel_name="helion_flce_fwd", baseline_name="torch_fwd", rtol=rtol, atol=atol) + run_example(liger_lm_head_ce_fwd_bwd, ref_lm_head_ce_fwd_bwd, (input, target), kernel_name="helion_flce_fwd_bwd", baseline_name="torch_fwd_bwd", rtol=rtol, atol=atol) From f402ff3db5e4c7efb2897a6e375e0cab6d6b68f1 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu, 6 Nov 2025 18:28:03 +0800 Subject: [PATCH 13/27] Add cut_cross_entropy comparison Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 105 ++++++++++++++---- 1 file changed, 82 insertions(+), 23 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 21d05146f..47284e19a 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -1,11 +1,34 @@ -import math - import helion import helion.language as hl import torch +from helion._testing import run_example + # Best config for default llama3Config(hidden_size=4096, vocab_size=32000) with 4096 bs*seqlen input -config = helion.Config(block_sizes=[32, 32, 256], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor'], load_eviction_policies=['first', '', 'first', 'last', '', '', 'last', 'first'], num_stages=5, num_warps=4, pid_type='flat', range_flattens=[None, True, False], range_multi_buffers=[None, True, False], range_num_stages=[0, 0, 0], range_unroll_factors=[0, 1, 1], range_warp_specializes=[]) +config = helion.Config( + block_sizes=[32, 32, 256], + indexing=[ + "tensor_descriptor", + "tensor_descriptor", + "pointer", + "tensor_descriptor", + "pointer", + "pointer", + "tensor_descriptor", + "pointer", + "tensor_descriptor", + ], + load_eviction_policies=["first", "", "first", "last", "", "", "last", "first"], + num_stages=5, + num_warps=4, + pid_type="flat", + range_flattens=[None, True, False], + range_multi_buffers=[None, True, False], + range_num_stages=[0, 0, 0], + range_unroll_factors=[0, 1, 1], + range_warp_specializes=[], +) + def helion_lock_acquire(lock_ptr, lock_index): hl.inline_triton( @@ -17,6 +40,7 @@ def helion_lock_acquire(lock_ptr, lock_index): output_like=None, ) + def helion_lock_release(lock_ptr, lock_index): hl.inline_triton( """ @@ -99,14 +123,15 @@ def fused_linear_cross_entropy_fwd( nll[tile_bt] = nll_tile lse[tile_bt] = lse_tile - + if reduction != "none": - loss = nll.sum() + loss = nll.sum() else: loss = nll - + return loss, lse + @helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) def fused_linear_cross_entropy_bwd( x: torch.Tensor, @@ -125,9 +150,9 @@ def fused_linear_cross_entropy_bwd( grad_w = torch.zeros_like(weight, dtype=torch.float32) n_non_ignore = (target != ignore_index).sum().unsqueeze(0) - num_block_bt = (BT + block_size_bt - 1)//block_size_bt - num_block_h = (H + block_size_h - 1)//block_size_h - num_block_v = (V + block_size_v - 1)//block_size_v + num_block_bt = (BT + block_size_bt - 1) // block_size_bt + num_block_h = (H + block_size_h - 1) // block_size_h + num_block_v = (V + block_size_v - 1) // block_size_v grad_x_lock = torch.zeros((num_block_bt, num_block_h), dtype=torch.int32, device=x.device) grad_w_lock = torch.zeros((num_block_v, num_block_h), dtype=torch.int32, device=x.device) # backward @@ -165,8 +190,8 @@ def fused_linear_cross_entropy_bwd( grad_x[tile_bt, tile_h] += partial_grad_x helion_lock_release(grad_x_lock, tile_bt.id * num_block_h + tile_h.id) # hl.atomic_add(grad_x, [tile_bt, tile_h], partial_grad_x) - - # for tile_h in hl.tile(H, block_size=block_size_h): + + # for tile_h in hl.tile(H, block_size=block_size_h): # grad_w = grad_logits.T[tile_v, tile_bt] @ x[tile_bt, tile_h] rhs_tile = x[tile_bt, tile_h] partial_grad_w = hl.dot(grad_logits_tile.T, rhs_tile, out_dtype=torch.float32) @@ -175,8 +200,6 @@ def fused_linear_cross_entropy_bwd( helion_lock_release(grad_w_lock, tile_v.id * num_block_h + tile_h.id) # hl.atomic_add(grad_w, [tile_v, tile_h], partial_grad_w) - - return grad_x, grad_w @@ -265,6 +288,28 @@ def forward(self, x, target): return self.flce(x, self.lm_head.weight, target) +from functools import partial + +from cut_cross_entropy import linear_cross_entropy + + +class CutLMHeadCE(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + ignore_index: int = -100, + reduction: str = "mean", + ): + super().__init__() + self.lm_head = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) + self.flce = partial(linear_cross_entropy, ignore_index=ignore_index, reduction=reduction, return_lse=False) + + def forward(self, x, target): + return self.flce(x, self.lm_head.weight, target) + + if __name__ == "__main__": torch.manual_seed(0) torch.cuda.manual_seed(0) @@ -286,7 +331,7 @@ def forward(self, x, target): reduction = "mean" ignore_index = -100 rtol = 1e-2 - atol = 1e-2 + atol = 1e-1 input = torch.randn(batch_size * seq_len, hidden_size, device=device, requires_grad=True) weight = torch.randn(vocab_size, hidden_size, device=device, requires_grad=True) @@ -315,24 +360,38 @@ def forward(self, x, target): liger_loss.backward() ref_loss.backward() - torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=rtol, atol=atol) + torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=rtol, atol=atol * 10) torch.testing.assert_close( liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=rtol, atol=atol ) - # Benchmark - from functools import partial - - from helion._testing import run_example + cce_lm_head_ce = CutLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) + cce_lm_head_ce.lm_head.weight.data = weight.data def fwd_bwd_fn(input, target, fn): loss = fn(input, target) loss.backward() return loss + liger_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=liger_lm_head_ce) ref_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=ref_lm_head_ce) - - - run_example(liger_lm_head_ce, ref_lm_head_ce, (input, target), kernel_name="helion_flce_fwd", baseline_name="torch_fwd", rtol=rtol, atol=atol) - run_example(liger_lm_head_ce_fwd_bwd, ref_lm_head_ce_fwd_bwd, (input, target), kernel_name="helion_flce_fwd_bwd", baseline_name="torch_fwd_bwd", rtol=rtol, atol=atol) + cce_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=cce_lm_head_ce) + + run_example( + liger_lm_head_ce, + {"torch_fwd": ref_lm_head_ce, "cce_fwd": cce_lm_head_ce}, + (input, target), + kernel_name="helion_flce_fwd", + rtol=rtol * 10, + atol=atol, + ) + if reduction != "none": + run_example( + liger_lm_head_ce_fwd_bwd, + {"torch_fwd_bwd": ref_lm_head_ce_fwd_bwd, "cce_fwd_bwd": cce_lm_head_ce_fwd_bwd}, + (input, target), + kernel_name="helion_flce_fwd_bwd", + rtol=rtol, + atol=atol, + ) From c1560228533cceed525846da5332011a1f815f4a Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu, 6 Nov 2025 18:50:47 +0800 Subject: [PATCH 14/27] Add LigerFusedLinearCrossEntropy for comparison Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 47284e19a..28509cf94 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -309,6 +309,24 @@ def __init__( def forward(self, x, target): return self.flce(x, self.lm_head.weight, target) +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss + + +class TritonLigerLMHeadCE(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + ignore_index: int = -100, + reduction: str = "mean", + ): + super().__init__() + self.lm_head = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) + self.flce = LigerFusedLinearCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) + + def forward(self, x, target): + return self.flce(x, self.lm_head.weight, target, None) if __name__ == "__main__": torch.manual_seed(0) @@ -378,9 +396,19 @@ def fwd_bwd_fn(input, target, fn): ref_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=ref_lm_head_ce) cce_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=cce_lm_head_ce) + triton_liger_lm_head_ce = TritonLigerLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to( + device=device + ) + triton_liger_lm_head_ce.lm_head.weight.data = weight.data + triton_liger_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=triton_liger_lm_head_ce) + run_example( liger_lm_head_ce, - {"torch_fwd": ref_lm_head_ce, "cce_fwd": cce_lm_head_ce}, + { + "torch_fwd": ref_lm_head_ce, + "cce_fwd": cce_lm_head_ce, + "triton_flce_fwd": triton_liger_lm_head_ce, + }, (input, target), kernel_name="helion_flce_fwd", rtol=rtol * 10, @@ -389,7 +417,11 @@ def fwd_bwd_fn(input, target, fn): if reduction != "none": run_example( liger_lm_head_ce_fwd_bwd, - {"torch_fwd_bwd": ref_lm_head_ce_fwd_bwd, "cce_fwd_bwd": cce_lm_head_ce_fwd_bwd}, + { + "torch_fwd_bwd": ref_lm_head_ce_fwd_bwd, + "cce_fwd_bwd": cce_lm_head_ce_fwd_bwd, + "triton_flce_fwd_bwd": triton_liger_lm_head_ce_fwd_bwd, + }, (input, target), kernel_name="helion_flce_fwd_bwd", rtol=rtol, From 719344dbfcf4336d0fdaa587ee94d1904391ea18 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu, 6 Nov 2025 18:52:55 +0800 Subject: [PATCH 15/27] Add IMA error comment to liger flce Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- src/liger_kernel/ops/helion/fused_linear_cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 28509cf94..90744afc7 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -407,7 +407,7 @@ def fwd_bwd_fn(input, target, fn): { "torch_fwd": ref_lm_head_ce, "cce_fwd": cce_lm_head_ce, - "triton_flce_fwd": triton_liger_lm_head_ce, + "triton_flce_fwd": triton_liger_lm_head_ce, # this will ecounter illegal memory access error }, (input, target), kernel_name="helion_flce_fwd", @@ -420,7 +420,7 @@ def fwd_bwd_fn(input, target, fn): { "torch_fwd_bwd": ref_lm_head_ce_fwd_bwd, "cce_fwd_bwd": cce_lm_head_ce_fwd_bwd, - "triton_flce_fwd_bwd": triton_liger_lm_head_ce_fwd_bwd, + "triton_flce_fwd_bwd": triton_liger_lm_head_ce_fwd_bwd, # ditto }, (input, target), kernel_name="helion_flce_fwd_bwd", From 81d0e9807b5f5a75b699c66dcd1bb0535c07f4eb Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu, 6 Nov 2025 19:57:14 +0800 Subject: [PATCH 16/27] Fix incorrect liger flce args positions Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 90744afc7..36c7ce256 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -326,7 +326,7 @@ def __init__( self.flce = LigerFusedLinearCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) def forward(self, x, target): - return self.flce(x, self.lm_head.weight, target, None) + return self.flce(self.lm_head.weight, x, target, None) if __name__ == "__main__": torch.manual_seed(0) @@ -355,6 +355,9 @@ def forward(self, x, target): weight = torch.randn(vocab_size, hidden_size, device=device, requires_grad=True) target = torch.randint(0, vocab_size, (batch_size * seq_len,), device=device) + print(f"{input.shape=}") + print(f"{input.clone().detach().shape=}") + # Init ref_lm_head_ce = TorchLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) liger_lm_head_ce = LigerLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) @@ -402,12 +405,13 @@ def fwd_bwd_fn(input, target, fn): triton_liger_lm_head_ce.lm_head.weight.data = weight.data triton_liger_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=triton_liger_lm_head_ce) + print(f"{input.shape=}") run_example( liger_lm_head_ce, { "torch_fwd": ref_lm_head_ce, "cce_fwd": cce_lm_head_ce, - "triton_flce_fwd": triton_liger_lm_head_ce, # this will ecounter illegal memory access error + "triton_flce_fwd": triton_liger_lm_head_ce, }, (input, target), kernel_name="helion_flce_fwd", @@ -420,7 +424,7 @@ def fwd_bwd_fn(input, target, fn): { "torch_fwd_bwd": ref_lm_head_ce_fwd_bwd, "cce_fwd_bwd": cce_lm_head_ce_fwd_bwd, - "triton_flce_fwd_bwd": triton_liger_lm_head_ce_fwd_bwd, # ditto + "triton_flce_fwd_bwd": triton_liger_lm_head_ce_fwd_bwd, }, (input, target), kernel_name="helion_flce_fwd_bwd", From fd05527b32b7a4aab7b68265b4631060f9f2631d Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu, 6 Nov 2025 20:10:52 +0800 Subject: [PATCH 17/27] Remove lock functions wrappers Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 31 ++++--------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 36c7ce256..e0ae247ee 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -30,27 +30,6 @@ ) -def helion_lock_acquire(lock_ptr, lock_index): - hl.inline_triton( - """ - while tl.atomic_cas({0} + {1}, 0, 1, sem="acquire") == 1: - pass - """, - args=(lock_ptr, lock_index), - output_like=None, - ) - - -def helion_lock_release(lock_ptr, lock_index): - hl.inline_triton( - """ - tl.atomic_xchg({0} + {1}, 0, sem="release") - """, - args=(lock_ptr, lock_index), - output_like=None, - ) - - # @helion.kernel(config=config, ignore_warnings=[helion.exc.TensorOperationInWrapper]) @helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) def fused_linear_cross_entropy_fwd( @@ -186,18 +165,20 @@ def fused_linear_cross_entropy_bwd( # grad_x = grad_logits @ weight rhs_tile = weight[tile_v, tile_h] partial_grad_x = hl.dot(grad_logits_tile, rhs_tile, out_dtype=torch.float32) - helion_lock_acquire(grad_x_lock, tile_bt.id * num_block_h + tile_h.id) + while hl.atomic_cas(grad_x_lock, [tile_bt.id, tile_h.id], 0, 1, sem="acquire") == 1: + pass grad_x[tile_bt, tile_h] += partial_grad_x - helion_lock_release(grad_x_lock, tile_bt.id * num_block_h + tile_h.id) + hl.atomic_xchg(grad_x_lock, [tile_bt.id, tile_h.id], 0, sem="release") # hl.atomic_add(grad_x, [tile_bt, tile_h], partial_grad_x) # for tile_h in hl.tile(H, block_size=block_size_h): # grad_w = grad_logits.T[tile_v, tile_bt] @ x[tile_bt, tile_h] rhs_tile = x[tile_bt, tile_h] partial_grad_w = hl.dot(grad_logits_tile.T, rhs_tile, out_dtype=torch.float32) - helion_lock_acquire(grad_w_lock, tile_v.id * num_block_h + tile_h.id) + while hl.atomic_cas(grad_w_lock, [tile_v.id, tile_h.id], 0, 1, sem="acquire") == 1: + pass grad_w[tile_v, tile_h] += partial_grad_w - helion_lock_release(grad_w_lock, tile_v.id * num_block_h + tile_h.id) + hl.atomic_xchg(grad_w_lock, [tile_v.id, tile_h.id], 0, sem="release") # hl.atomic_add(grad_w, [tile_v, tile_h], partial_grad_w) return grad_x, grad_w From 34decd43e90f14ef5aea8b64ff25b5119bd8da2d Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu, 6 Nov 2025 15:30:11 +0000 Subject: [PATCH 18/27] Update best configs for h100 with BT=2048, H=4096, V=32000 Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 37 ++++--------------- 1 file changed, 8 insertions(+), 29 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index e0ae247ee..9c078daed 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -5,33 +5,10 @@ from helion._testing import run_example # Best config for default llama3Config(hidden_size=4096, vocab_size=32000) with 4096 bs*seqlen input -config = helion.Config( - block_sizes=[32, 32, 256], - indexing=[ - "tensor_descriptor", - "tensor_descriptor", - "pointer", - "tensor_descriptor", - "pointer", - "pointer", - "tensor_descriptor", - "pointer", - "tensor_descriptor", - ], - load_eviction_policies=["first", "", "first", "last", "", "", "last", "first"], - num_stages=5, - num_warps=4, - pid_type="flat", - range_flattens=[None, True, False], - range_multi_buffers=[None, True, False], - range_num_stages=[0, 0, 0], - range_unroll_factors=[0, 1, 1], - range_warp_specializes=[], -) - - -# @helion.kernel(config=config, ignore_warnings=[helion.exc.TensorOperationInWrapper]) -@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) +h100_fwd_config=helion.Config(block_sizes=[64, 32, 512], indexing=['pointer', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'tensor_descriptor'], load_eviction_policies=['', 'last', 'last', 'last'], num_stages=8, num_warps=16, pid_type='flat', range_flattens=[None, False, None], range_multi_buffers=[None, True, False], range_num_stages=[0, 3, 3], range_unroll_factors=[0, 0, 1], range_warp_specializes=[]) + +@helion.kernel(config=h100_fwd_config, static_shapes=True) +# @helion.kernel(autotune_effort="quick", ignore_warnings=[helion.exc.TensorOperationInWrapper]) def fused_linear_cross_entropy_fwd( x: torch.Tensor, weight: torch.Tensor, @@ -111,7 +88,9 @@ def fused_linear_cross_entropy_fwd( return loss, lse -@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) +h100_bwd_config = helion.Config(block_sizes=[128, 64, 128], indexing=['pointer', 'pointer', 'pointer', 'tensor_descriptor', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[64], load_eviction_policies=['', 'first', 'last', '', '', 'last', 'first', 'first', ''], loop_orders=[[0, 1]], num_stages=7, num_warps=8, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, None], range_num_stages=[0, 3], range_unroll_factors=[0, 1], range_warp_specializes=[]) +@helion.kernel(config=h100_bwd_config, static_shapes=True) +# @helion.kernel(autotune_effort="quick", ignore_warnings=[helion.exc.TensorOperationInWrapper]) def fused_linear_cross_entropy_bwd( x: torch.Tensor, weight: torch.Tensor, @@ -319,7 +298,7 @@ def forward(self, x, target): device = "cuda" batch_size = 2 - seq_len = 1024 + seq_len = 2048 hidden_size = 4096 vocab_size = 32000 # batch_size = 2 From 69053909e285608a30391e5bca5ee7f3f581c553 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu, 6 Nov 2025 16:03:54 +0000 Subject: [PATCH 19/27] Clean up handwriting test and let run_example() handle correctness test Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 57 ++++++------------- 1 file changed, 16 insertions(+), 41 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 9c078daed..26a30086a 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -298,13 +298,12 @@ def forward(self, x, target): device = "cuda" batch_size = 2 - seq_len = 2048 - hidden_size = 4096 - vocab_size = 32000 - # batch_size = 2 - # seq_len = 256 - # hidden_size = 512 - # vocab_size = 1024 + seq_len = 4096 + hidden_size = 2304 + vocab_size = 262208 + + print(f"BT={batch_size * seq_len}, H={hidden_size}, V={vocab_size}") + dtype = torch.float32 reduction = "mean" ignore_index = -100 @@ -315,41 +314,19 @@ def forward(self, x, target): weight = torch.randn(vocab_size, hidden_size, device=device, requires_grad=True) target = torch.randint(0, vocab_size, (batch_size * seq_len,), device=device) - print(f"{input.shape=}") - print(f"{input.clone().detach().shape=}") - # Init ref_lm_head_ce = TorchLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) liger_lm_head_ce = LigerLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) + cce_lm_head_ce = CutLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) + triton_liger_lm_head_ce = TritonLigerLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to( + device=device + ) ref_lm_head_ce.lm_head.weight.data = weight.data liger_lm_head_ce.lm_head.weight.data = weight.data - - ref_input = input.detach().clone().requires_grad_(True) - liger_input = input.detach().clone().requires_grad_(True) - - # Forward pass - ref_loss: torch.Tensor = ref_lm_head_ce(ref_input, target) - liger_loss: torch.Tensor = liger_lm_head_ce(liger_input, target) - - torch.testing.assert_close(liger_loss, ref_loss, rtol=rtol, atol=atol) - - # Backward pass (backward() with reduction=="none" is not supported yet) - if reduction == "none": - pass - else: - liger_loss.backward() - ref_loss.backward() - - torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=rtol, atol=atol * 10) - torch.testing.assert_close( - liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=rtol, atol=atol - ) - - # Benchmark - cce_lm_head_ce = CutLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) cce_lm_head_ce.lm_head.weight.data = weight.data - + triton_liger_lm_head_ce.lm_head.weight.data = weight.data + def fwd_bwd_fn(input, target, fn): loss = fn(input, target) loss.backward() @@ -358,14 +335,12 @@ def fwd_bwd_fn(input, target, fn): liger_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=liger_lm_head_ce) ref_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=ref_lm_head_ce) cce_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=cce_lm_head_ce) - - triton_liger_lm_head_ce = TritonLigerLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to( - device=device - ) - triton_liger_lm_head_ce.lm_head.weight.data = weight.data triton_liger_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=triton_liger_lm_head_ce) + + # Test and Benchmark + + - print(f"{input.shape=}") run_example( liger_lm_head_ce, { From b508243a3311896bacc4fd950effc19e91b548a2 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat, 8 Nov 2025 15:00:47 +0800 Subject: [PATCH 20/27] Add chunk version of flce backward Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 211 +++++++++++++++--- 1 file changed, 178 insertions(+), 33 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 26a30086a..12d8a9b36 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -5,10 +5,23 @@ from helion._testing import run_example # Best config for default llama3Config(hidden_size=4096, vocab_size=32000) with 4096 bs*seqlen input -h100_fwd_config=helion.Config(block_sizes=[64, 32, 512], indexing=['pointer', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'tensor_descriptor'], load_eviction_policies=['', 'last', 'last', 'last'], num_stages=8, num_warps=16, pid_type='flat', range_flattens=[None, False, None], range_multi_buffers=[None, True, False], range_num_stages=[0, 3, 3], range_unroll_factors=[0, 0, 1], range_warp_specializes=[]) - -@helion.kernel(config=h100_fwd_config, static_shapes=True) -# @helion.kernel(autotune_effort="quick", ignore_warnings=[helion.exc.TensorOperationInWrapper]) +h100_fwd_config = helion.Config( + block_sizes=[64, 32, 512], + indexing=["pointer", "pointer", "tensor_descriptor", "pointer", "tensor_descriptor", "tensor_descriptor"], + load_eviction_policies=["", "last", "last", "last"], + num_stages=8, + num_warps=16, + pid_type="flat", + range_flattens=[None, False, None], + range_multi_buffers=[None, True, False], + range_num_stages=[0, 3, 3], + range_unroll_factors=[0, 0, 1], + range_warp_specializes=[], +) + + +# @helion.kernel(config=h100_fwd_config, static_shapes=True) +@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) def fused_linear_cross_entropy_fwd( x: torch.Tensor, weight: torch.Tensor, @@ -88,9 +101,37 @@ def fused_linear_cross_entropy_fwd( return loss, lse -h100_bwd_config = helion.Config(block_sizes=[128, 64, 128], indexing=['pointer', 'pointer', 'pointer', 'tensor_descriptor', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[64], load_eviction_policies=['', 'first', 'last', '', '', 'last', 'first', 'first', ''], loop_orders=[[0, 1]], num_stages=7, num_warps=8, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, None], range_num_stages=[0, 3], range_unroll_factors=[0, 1], range_warp_specializes=[]) -@helion.kernel(config=h100_bwd_config, static_shapes=True) -# @helion.kernel(autotune_effort="quick", ignore_warnings=[helion.exc.TensorOperationInWrapper]) +h100_bwd_config = helion.Config( + block_sizes=[128, 64, 128], + indexing=[ + "pointer", + "pointer", + "pointer", + "tensor_descriptor", + "tensor_descriptor", + "pointer", + "tensor_descriptor", + "tensor_descriptor", + "pointer", + "tensor_descriptor", + "tensor_descriptor", + ], + l2_groupings=[64], + load_eviction_policies=["", "first", "last", "", "", "last", "first", "first", ""], + loop_orders=[[0, 1]], + num_stages=7, + num_warps=8, + pid_type="flat", + range_flattens=[None, True], + range_multi_buffers=[None, None], + range_num_stages=[0, 3], + range_unroll_factors=[0, 1], + range_warp_specializes=[], +) + + +# @helion.kernel(config=h100_bwd_config, static_shapes=True) +@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) def fused_linear_cross_entropy_bwd( x: torch.Tensor, weight: torch.Tensor, @@ -163,16 +204,98 @@ def fused_linear_cross_entropy_bwd( return grad_x, grad_w +@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) +def _grad_logit_compute( + x: torch.Tensor, + weight: torch.Tensor, + target: torch.Tensor, + lse: torch.Tensor, + n_non_ignore: torch.Tensor, + reduction: str = "mean", +): + BT, H = x.size() + V = weight.size(0) + + block_size_bt = hl.register_block_size(BT) + block_size_h = hl.register_block_size(H) + block_size_v = hl.register_block_size(V) + grad_logits = torch.zeros((BT, V), dtype=torch.float32, device=x.device) + for tile_bt, tile_v in hl.tile([BT, V], block_size=(block_size_bt, block_size_v)): + if reduction == "mean": + n_non_ignore_value = hl.load(n_non_ignore, [0], eviction_policy="evict_last") + # Restore logits + acc2 = hl.zeros([tile_bt, tile_v], dtype=torch.float32) + for tile_h in hl.tile(H, block_size=block_size_h): + x_tile = x[tile_bt, tile_h] + weight_tile = weight[tile_v, tile_h] + acc2 = hl.dot(x_tile, weight_tile.T, acc=acc2, out_dtype=torch.float32) + + # softmax(x_i) = exp(x_i) / sum(exp(x_i)) + # = exp(x_i) / log(exp(sum(x_i))) + # = exp(x_i) / lse = exp(x_i - lse) + lse_tile = lse[tile_bt] + target_indices = target[tile_bt].unsqueeze(1) # [tile_bt, 1] + + grad_logits_tile = torch.exp(acc2 - lse_tile[:, None]) + offset = tile_v.index.unsqueeze(0) # [1, tile_v] + mask = target_indices == offset # [tile_bt, tile_v] + grad_logits_tile = grad_logits_tile - mask.float() + # handle out of bound values in grad_logits_tile + grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :]) + + if reduction == "mean": + grad_logits_tile /= n_non_ignore_value + + grad_logits[tile_bt, tile_v] = grad_logits_tile + return grad_logits + + +def fused_linear_cross_entropy_bwd_chunk( + x: torch.Tensor, + weight: torch.Tensor, + target: torch.Tensor, + lse: torch.Tensor, + ignore_index: int = -100, + reduction: str = "mean", +): + BT, H = x.size() + V = weight.size(0) + # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + num_chunks = (V + H - 1) // H + chunk_size = (BT + num_chunks - 1) // num_chunks + grad_x = torch.zeros_like(x, dtype=torch.float32) + grad_w = torch.zeros_like(weight, dtype=torch.float32) + n_non_ignore = (target != ignore_index).sum().unsqueeze(0) + + x_chunks = torch.chunk(x, chunks=num_chunks, dim=0) + lse_chunks = torch.chunk(lse, chunks=num_chunks, dim=0) + target_chunks = torch.chunk(target, chunks=num_chunks, dim=0) + + for chunk_id, (x_chunk, target_chunk, lse_chunk) in enumerate(zip(x_chunks, target_chunks, lse_chunks)): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + + grad_logits_chunk = _grad_logit_compute( + x_chunk, + weight, + target_chunk, + lse_chunk, + n_non_ignore, + reduction, + ) + + grad_x[start_idx:end_idx] = grad_logits_chunk @ weight + grad_w += torch.mm(grad_logits_chunk.T, x_chunk).float() + + return grad_x, grad_w + + class LigerFusedLinearCrossEntropyHelionFunction(torch.autograd.Function): @staticmethod - def forward( - ctx, - _input, - weight, - target, - ignore_index=-100, - reduction="mean", - ): + def forward(ctx, _input, weight, target, ignore_index=-100, reduction="mean", bwd_impl="chunk"): + assert bwd_impl in ["chunk", "cce"] loss, lse = fused_linear_cross_entropy_fwd( _input, weight, @@ -182,6 +305,7 @@ def forward( ) ctx.ignore_index = ignore_index ctx.reduction = reduction + ctx.bwd_impl = bwd_impl ctx.save_for_backward(_input, lse) return loss @@ -189,7 +313,11 @@ def forward( def backward(ctx, grad_output): assert grad_output.ndim == 0, "token_scaling is not supported. grad_output must be a scalar" _input, lse = ctx.saved_tensors - grad_input, grad_weight = fused_linear_cross_entropy_bwd( + if ctx.bwd_impl == "cce": + bwd_fn = fused_linear_cross_entropy_bwd + elif ctx.bwd_impl == "chunk": + bwd_fn = fused_linear_cross_entropy_bwd_chunk + grad_input, grad_weight = bwd_fn( _input, weight, target, @@ -197,18 +325,19 @@ def backward(ctx, grad_output): ctx.ignore_index, ctx.reduction, ) - return grad_input * grad_output, grad_weight * grad_output, None, None, None + return grad_input * grad_output, grad_weight * grad_output, None, None, None, None class LigerFusedLinearCrossEntropyHelion(torch.nn.Module): - def __init__(self, ignore_index=-100, reduction="mean"): + def __init__(self, ignore_index=-100, reduction="mean", bwd_impl="chunk"): super().__init__() self.ignore_index = ignore_index self.reduction = reduction + self.bwd_impl = bwd_impl def forward(self, _input, weight, target): return LigerFusedLinearCrossEntropyHelionFunction.apply( - _input, weight, target, self.ignore_index, self.reduction + _input, weight, target, self.ignore_index, self.reduction, self.bwd_impl ) @@ -239,10 +368,13 @@ def __init__( dtype: torch.dtype, ignore_index: int = -100, reduction: str = "mean", + bwd_impl: str = "cce", ): super().__init__() self.lm_head = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) - self.flce = LigerFusedLinearCrossEntropyHelion(ignore_index=ignore_index, reduction=reduction) + self.flce = LigerFusedLinearCrossEntropyHelion( + ignore_index=ignore_index, reduction=reduction, bwd_impl=bwd_impl + ) def forward(self, x, target): return self.flce(x, self.lm_head.weight, target) @@ -269,6 +401,7 @@ def __init__( def forward(self, x, target): return self.flce(x, self.lm_head.weight, target) + from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss @@ -288,6 +421,7 @@ def __init__( def forward(self, x, target): return self.flce(self.lm_head.weight, x, target, None) + if __name__ == "__main__": torch.manual_seed(0) torch.cuda.manual_seed(0) @@ -297,13 +431,17 @@ def forward(self, x, target): device = "cuda" + # batch_size = 2 + # seq_len = 4096 + # hidden_size = 2304 + # vocab_size = 262208 batch_size = 2 - seq_len = 4096 - hidden_size = 2304 - vocab_size = 262208 + seq_len = 1024 + hidden_size = 512 + vocab_size = 1024 print(f"BT={batch_size * seq_len}, H={hidden_size}, V={vocab_size}") - + dtype = torch.float32 reduction = "mean" ignore_index = -100 @@ -316,7 +454,12 @@ def forward(self, x, target): # Init ref_lm_head_ce = TorchLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) - liger_lm_head_ce = LigerLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) + liger_lm_head_ce = LigerLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction, bwd_impl="cce").to( + device=device + ) + liger_chunk_lm_head_ce = LigerLMHeadCE( + hidden_size, vocab_size, dtype=dtype, reduction=reduction, bwd_impl="chunk" + ).to(device=device) cce_lm_head_ce = CutLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) triton_liger_lm_head_ce = TritonLigerLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to( device=device @@ -324,45 +467,47 @@ def forward(self, x, target): ref_lm_head_ce.lm_head.weight.data = weight.data liger_lm_head_ce.lm_head.weight.data = weight.data + liger_chunk_lm_head_ce.lm_head.weight.data = weight.data cce_lm_head_ce.lm_head.weight.data = weight.data triton_liger_lm_head_ce.lm_head.weight.data = weight.data - + def fwd_bwd_fn(input, target, fn): loss = fn(input, target) loss.backward() return loss liger_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=liger_lm_head_ce) + liger_chunk_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=liger_chunk_lm_head_ce) ref_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=ref_lm_head_ce) cce_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=cce_lm_head_ce) triton_liger_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=triton_liger_lm_head_ce) - - # Test and Benchmark - + # Test and Benchmark run_example( liger_lm_head_ce, { "torch_fwd": ref_lm_head_ce, "cce_fwd": cce_lm_head_ce, - "triton_flce_fwd": triton_liger_lm_head_ce, + "triton_flce_fwd": triton_liger_lm_head_ce, }, (input, target), - kernel_name="helion_flce_fwd", + kernel_name="helion_fwd", rtol=rtol * 10, atol=atol, ) if reduction != "none": run_example( - liger_lm_head_ce_fwd_bwd, + { + "helion_fwd_bwd_cce": liger_lm_head_ce_fwd_bwd, + "helion_fwd_bwd_chunk": liger_chunk_lm_head_ce_fwd_bwd, + }, { "torch_fwd_bwd": ref_lm_head_ce_fwd_bwd, "cce_fwd_bwd": cce_lm_head_ce_fwd_bwd, "triton_flce_fwd_bwd": triton_liger_lm_head_ce_fwd_bwd, }, (input, target), - kernel_name="helion_flce_fwd_bwd", rtol=rtol, atol=atol, ) From d901dac18596c02c559cd1a58e196f447fa163a5 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat, 8 Nov 2025 19:37:53 +0800 Subject: [PATCH 21/27] Add autotune misc Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 91 ++++++++++++++++++- 1 file changed, 90 insertions(+), 1 deletion(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 12d8a9b36..ba9ec6b1c 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -1,3 +1,5 @@ +import argparse + import helion import helion.language as hl import torch @@ -93,6 +95,7 @@ def fused_linear_cross_entropy_fwd( nll[tile_bt] = nll_tile lse[tile_bt] = lse_tile + if reduction != "none": loss = nll.sum() else: @@ -178,6 +181,7 @@ def fused_linear_cross_entropy_bwd( # handle out of bound values in grad_logits_tile grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :]) + if reduction == "mean": grad_logits_tile /= n_non_ignore_value @@ -243,9 +247,12 @@ def _grad_logit_compute( # handle out of bound values in grad_logits_tile grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :]) + if reduction == "mean": grad_logits_tile /= n_non_ignore_value + + grad_logits[tile_bt, tile_v] = grad_logits_tile return grad_logits @@ -421,15 +428,77 @@ def __init__( def forward(self, x, target): return self.flce(self.lm_head.weight, x, target, None) +def autotune_kernels(model_config_dataset): + + def generate_flce_fwd_input(BT, V, H, dtype): + x = torch.randn(BT, H, device=device, dtype=dtype) + weight = torch.randn(V, H, device=device, dtype=dtype) + target = torch.randint(0, V, (BT,), device=device) + return (x, weight, target) + + for model_name, model_config in model_config_dataset.items(): + for dtype in [torch.bfloat16, torch.float32]: + BT = 4096 + args = generate_flce_fwd_input(BT, model_config["hidden_size"], model_config["vocab_size"]) + config = fused_linear_cross_entropy_fwd.autotune(args) + if dtype == torch.bfloat16: + dtype_str = "bf16" + elif dtype == torch.float32: + dtype_str = "fp32" + config.save(f"configs/fused_linear_cross_entropy_fwd_{model_name}_{dtype_str}.json") + + def generate_flce_bwd_input(BT, V, H, dtype): + x = torch.randn(BT, H, device=device, dtype=dtype) + weight = torch.randn(V, H, device=device, dtype=dtype) + target = torch.randint(0, V, (BT,), device=device) + lse = torch.randn(BT, device=device, dtype=torch.float32) + return (x, weight, target, lse) + + for model_name, model_config in model_config_dataset.items(): + for dtype in [torch.bfloat16, torch.float32]: + BT = 4096 + args = generate_flce_bwd_input(BT, model_config["hidden_size"], model_config["vocab_size"]) + config = fused_linear_cross_entropy_bwd.autotune(args) + if dtype == torch.bfloat16: + dtype_str = "bf16" + elif dtype == torch.float32: + dtype_str = "fp32" + config.save(f"configs/fused_linear_cross_entropy_bwd_{model_name}_{dtype_str}.json") + + def generate_grad_logits_compute_input(BT, V, H, dtype): + x = torch.randn(BT, H, device=device, dtype=dtype) + weight = torch.randn(V, H, device=device, dtype=dtype) + target = torch.randint(0, V, (BT,), device=device) + lse = torch.randn(BT, device=device, dtype=torch.float32) + n_non_ignore = (target != ignore_index).sum().unsqueeze(0) + return (x, weight, target, lse, n_non_ignore) + + for model_name, model_config in model_config_dataset.items(): + for dtype in [torch.bfloat16, torch.float32]: + BT = 4096 + args = generate_grad_logits_compute_input(BT, model_config["hidden_size"], model_config["vocab_size"]) + config = _grad_logit_compute.autotune(args) + if dtype == torch.bfloat16: + dtype_str = "bf16" + elif dtype == torch.float32: + dtype_str = "fp32" + config.save(f"configs/_grad_logit_compute_{model_name}_{dtype_str}.json") if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--autotune', default=False) + args = parser.parse_args() + torch.manual_seed(0) torch.cuda.manual_seed(0) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False torch.backends.cudnn.enabled = True - device = "cuda" + from liger_kernel.utils import infer_device + device = infer_device() + torch_device = getattr(torch, device) + gpu_name = torch_device.get_device_name(torch_device.current_device()) # batch_size = 2 # seq_len = 4096 @@ -452,6 +521,26 @@ def forward(self, x, target): weight = torch.randn(vocab_size, hidden_size, device=device, requires_grad=True) target = torch.randint(0, vocab_size, (batch_size * seq_len,), device=device) + model_config_dataset = { + "llama": { + "hidden_size": 4096, + "vocab_size": 32000, + }, + "gemma3": { + "hidden_size": 2305, + "vocab_size": 262208, + }, + "qwen3": { + "hidden_size": 4096, + "vocab_size": 151936, + }, + } + + if args.autotune: + print("autotuning all kernels...") + autotune_kernels(model_config_dataset=model_config_dataset) + + # Init ref_lm_head_ce = TorchLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) liger_lm_head_ce = LigerLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction, bwd_impl="cce").to( From 5ecb417e60c39c59a9c3b9d7e744c35aaf4d7812 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat, 8 Nov 2025 19:38:12 +0800 Subject: [PATCH 22/27] Fix ignore_index Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index ba9ec6b1c..20e8a2795 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -89,6 +89,9 @@ def fused_linear_cross_entropy_fwd( lse_tile = m_i + torch.log(d_i) nll_tile = nll_tile + lse_tile + # handle ignore index + nll_tile = nll_tile * (target_indices.ravel() != ignore_index) + if reduction == "mean": nll_tile /= n_non_ignore_value @@ -101,7 +104,7 @@ def fused_linear_cross_entropy_fwd( else: loss = nll - return loss, lse + return loss.to(x.dtype), lse h100_bwd_config = helion.Config( @@ -159,6 +162,8 @@ def fused_linear_cross_entropy_bwd( grad_w_lock = torch.zeros((num_block_v, num_block_h), dtype=torch.int32, device=x.device) # backward for tile_bt, tile_v in hl.tile([BT, V], block_size=(block_size_bt, block_size_v)): + if reduction == "mean": + n_non_ignore_value = hl.load(n_non_ignore, [0], eviction_policy="evict_last") # Restore logits acc2 = hl.zeros([tile_bt, tile_v], dtype=torch.float32) for tile_h in hl.tile(H, block_size=block_size_h): @@ -171,8 +176,6 @@ def fused_linear_cross_entropy_bwd( # = exp(x_i) / lse = exp(x_i - lse) lse_tile = lse[tile_bt] target_indices = target[tile_bt].unsqueeze(1) # [tile_bt, 1] - if reduction == "mean": - n_non_ignore_value = hl.load(n_non_ignore, [0]) grad_logits_tile = torch.exp(acc2 - lse_tile[:, None]) offset = tile_v.index.unsqueeze(0) # [1, tile_v] @@ -181,6 +184,8 @@ def fused_linear_cross_entropy_bwd( # handle out of bound values in grad_logits_tile grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :]) + # handle ignore index + grad_logits_tile = grad_logits_tile * (target_indices != ignore_index) if reduction == "mean": grad_logits_tile /= n_non_ignore_value @@ -205,7 +210,7 @@ def fused_linear_cross_entropy_bwd( hl.atomic_xchg(grad_w_lock, [tile_v.id, tile_h.id], 0, sem="release") # hl.atomic_add(grad_w, [tile_v, tile_h], partial_grad_w) - return grad_x, grad_w + return grad_x.to(x.dtype), grad_w.to(x.dtype) @helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) @@ -215,6 +220,7 @@ def _grad_logit_compute( target: torch.Tensor, lse: torch.Tensor, n_non_ignore: torch.Tensor, + ignore_index: int = -100, reduction: str = "mean", ): BT, H = x.size() @@ -247,6 +253,8 @@ def _grad_logit_compute( # handle out of bound values in grad_logits_tile grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :]) + # handle ignore index + grad_logits_tile = grad_logits_tile * (target_indices != ignore_index) if reduction == "mean": grad_logits_tile /= n_non_ignore_value @@ -296,7 +304,7 @@ def fused_linear_cross_entropy_bwd_chunk( grad_x[start_idx:end_idx] = grad_logits_chunk @ weight grad_w += torch.mm(grad_logits_chunk.T, x_chunk).float() - return grad_x, grad_w + return grad_x.to(x.dtype), grad_w.to(x.dtype) class LigerFusedLinearCrossEntropyHelionFunction(torch.autograd.Function): From 74553dc50c5bf1fd633d2eb858ec1f3a54e2d858 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat, 8 Nov 2025 19:55:53 +0800 Subject: [PATCH 23/27] Fix backward ctx.savetensors Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- src/liger_kernel/ops/helion/fused_linear_cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 20e8a2795..aef9bc7db 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -321,13 +321,13 @@ def forward(ctx, _input, weight, target, ignore_index=-100, reduction="mean", bw ctx.ignore_index = ignore_index ctx.reduction = reduction ctx.bwd_impl = bwd_impl - ctx.save_for_backward(_input, lse) + ctx.save_for_backward(_input, lse, weight, target) return loss @staticmethod def backward(ctx, grad_output): assert grad_output.ndim == 0, "token_scaling is not supported. grad_output must be a scalar" - _input, lse = ctx.saved_tensors + _input, lse, weight, target = ctx.saved_tensors if ctx.bwd_impl == "cce": bwd_fn = fused_linear_cross_entropy_bwd elif ctx.bwd_impl == "chunk": From d2b2372b529784544d784131883e875e99f0ebc1 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat, 8 Nov 2025 20:04:50 +0800 Subject: [PATCH 24/27] clean up Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 165 ++++++++++-------- 1 file changed, 94 insertions(+), 71 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index aef9bc7db..b3c1b6e56 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -6,6 +6,8 @@ from helion._testing import run_example +from liger_kernel.utils import infer_device + # Best config for default llama3Config(hidden_size=4096, vocab_size=32000) with 4096 bs*seqlen input h100_fwd_config = helion.Config( block_sizes=[64, 32, 512], @@ -98,7 +100,6 @@ def fused_linear_cross_entropy_fwd( nll[tile_bt] = nll_tile lse[tile_bt] = lse_tile - if reduction != "none": loss = nll.sum() else: @@ -259,8 +260,6 @@ def _grad_logit_compute( if reduction == "mean": grad_logits_tile /= n_non_ignore_value - - grad_logits[tile_bt, tile_v] = grad_logits_tile return grad_logits @@ -436,86 +435,96 @@ def __init__( def forward(self, x, target): return self.flce(self.lm_head.weight, x, target, None) -def autotune_kernels(model_config_dataset): - def generate_flce_fwd_input(BT, V, H, dtype): - x = torch.randn(BT, H, device=device, dtype=dtype) - weight = torch.randn(V, H, device=device, dtype=dtype) - target = torch.randint(0, V, (BT,), device=device) - return (x, weight, target) +def generate_flce_fwd_input(BT, V, H, dtype, device): + x = torch.randn(BT, H, device=device, dtype=dtype) + weight = torch.randn(V, H, device=device, dtype=dtype) + target = torch.randint(0, V, (BT,), device=device) + return (x, weight, target) + + +def generate_flce_bwd_input(BT, V, H, dtype, device): + x = torch.randn(BT, H, device=device, dtype=dtype) + weight = torch.randn(V, H, device=device, dtype=dtype) + target = torch.randint(0, V, (BT,), device=device) + lse = torch.randn(BT, device=device, dtype=torch.float32) + return (x, weight, target, lse) + + +def generate_grad_logits_compute_input(BT, V, H, dtype, device): + x = torch.randn(BT, H, device=device, dtype=dtype) + weight = torch.randn(V, H, device=device, dtype=dtype) + target = torch.randint(0, V, (BT,), device=device) + lse = torch.randn(BT, device=device, dtype=torch.float32) + n_non_ignore = (target != -100).sum().unsqueeze(0) + return (x, weight, target, lse, n_non_ignore) + + +def autotune_kernels(model_config_dataset): + device = infer_device() + torch_device = getattr(torch, device) + gpu_name = torch_device.get_device_name(torch_device.current_device()) for model_name, model_config in model_config_dataset.items(): for dtype in [torch.bfloat16, torch.float32]: BT = 4096 - args = generate_flce_fwd_input(BT, model_config["hidden_size"], model_config["vocab_size"]) + args = generate_flce_fwd_input( + BT, + model_config["hidden_size"], + model_config["vocab_size"], + dtype=dtype, + device=device, + ) config = fused_linear_cross_entropy_fwd.autotune(args) if dtype == torch.bfloat16: dtype_str = "bf16" elif dtype == torch.float32: dtype_str = "fp32" - config.save(f"configs/fused_linear_cross_entropy_fwd_{model_name}_{dtype_str}.json") - - def generate_flce_bwd_input(BT, V, H, dtype): - x = torch.randn(BT, H, device=device, dtype=dtype) - weight = torch.randn(V, H, device=device, dtype=dtype) - target = torch.randint(0, V, (BT,), device=device) - lse = torch.randn(BT, device=device, dtype=torch.float32) - return (x, weight, target, lse) + config.save(f"configs/fused_linear_cross_entropy_fwd_{gpu_name}_{model_name}_{dtype_str}.json") for model_name, model_config in model_config_dataset.items(): for dtype in [torch.bfloat16, torch.float32]: BT = 4096 - args = generate_flce_bwd_input(BT, model_config["hidden_size"], model_config["vocab_size"]) + args = generate_flce_bwd_input( + BT, + model_config["hidden_size"], + model_config["vocab_size"], + dtype=dtype, + device=device, + ) config = fused_linear_cross_entropy_bwd.autotune(args) if dtype == torch.bfloat16: dtype_str = "bf16" elif dtype == torch.float32: dtype_str = "fp32" - config.save(f"configs/fused_linear_cross_entropy_bwd_{model_name}_{dtype_str}.json") - - def generate_grad_logits_compute_input(BT, V, H, dtype): - x = torch.randn(BT, H, device=device, dtype=dtype) - weight = torch.randn(V, H, device=device, dtype=dtype) - target = torch.randint(0, V, (BT,), device=device) - lse = torch.randn(BT, device=device, dtype=torch.float32) - n_non_ignore = (target != ignore_index).sum().unsqueeze(0) - return (x, weight, target, lse, n_non_ignore) + config.save(f"configs/fused_linear_cross_entropy_bwd_{gpu_name}_{model_name}_{dtype_str}.json") for model_name, model_config in model_config_dataset.items(): for dtype in [torch.bfloat16, torch.float32]: BT = 4096 - args = generate_grad_logits_compute_input(BT, model_config["hidden_size"], model_config["vocab_size"]) + args = generate_grad_logits_compute_input( + BT, + model_config["hidden_size"], + model_config["vocab_size"], + dtype=dtype, + device=device, + ) config = _grad_logit_compute.autotune(args) if dtype == torch.bfloat16: dtype_str = "bf16" elif dtype == torch.float32: dtype_str = "fp32" - config.save(f"configs/_grad_logit_compute_{model_name}_{dtype_str}.json") - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--autotune', default=False) - args = parser.parse_args() + config.save(f"configs/_grad_logit_compute_{gpu_name}_{model_name}_{dtype_str}.json") - torch.manual_seed(0) - torch.cuda.manual_seed(0) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.enabled = True - from liger_kernel.utils import infer_device +def check(): device = infer_device() - torch_device = getattr(torch, device) - gpu_name = torch_device.get_device_name(torch_device.current_device()) - # batch_size = 2 - # seq_len = 4096 - # hidden_size = 2304 - # vocab_size = 262208 batch_size = 2 - seq_len = 1024 - hidden_size = 512 - vocab_size = 1024 + seq_len = 4096 + hidden_size = 4096 + vocab_size = 32000 + print(f"BT={batch_size * seq_len}, H={hidden_size}, V={vocab_size}") @@ -528,27 +537,6 @@ def generate_grad_logits_compute_input(BT, V, H, dtype): input = torch.randn(batch_size * seq_len, hidden_size, device=device, requires_grad=True) weight = torch.randn(vocab_size, hidden_size, device=device, requires_grad=True) target = torch.randint(0, vocab_size, (batch_size * seq_len,), device=device) - - model_config_dataset = { - "llama": { - "hidden_size": 4096, - "vocab_size": 32000, - }, - "gemma3": { - "hidden_size": 2305, - "vocab_size": 262208, - }, - "qwen3": { - "hidden_size": 4096, - "vocab_size": 151936, - }, - } - - if args.autotune: - print("autotuning all kernels...") - autotune_kernels(model_config_dataset=model_config_dataset) - - # Init ref_lm_head_ce = TorchLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) liger_lm_head_ce = LigerLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction, bwd_impl="cce").to( @@ -608,3 +596,38 @@ def fwd_bwd_fn(input, target, fn): rtol=rtol, atol=atol, ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--autotune", default=False) + parser.add_argument("--benchmark", default=True) + args = parser.parse_args() + + torch.manual_seed(0) + torch.cuda.manual_seed(0) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = True + + if args.benchmark: + check() + + model_config_dataset = { + "llama": { + "hidden_size": 4096, + "vocab_size": 32000, + }, + "gemma3": { + "hidden_size": 2305, + "vocab_size": 262208, + }, + "qwen3": { + "hidden_size": 4096, + "vocab_size": 151936, + }, + } + + if args.autotune: + print("autotuning all kernels...") + autotune_kernels(model_config_dataset=model_config_dataset) From 07deb98e459a15da18cedccc8b447931389ade96 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat, 8 Nov 2025 12:43:56 +0000 Subject: [PATCH 25/27] Fix autotune fuction Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- .../ops/helion/fused_linear_cross_entropy.py | 46 +++++++++++++++---- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index b3c1b6e56..fb7c3588b 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -459,7 +459,7 @@ def generate_grad_logits_compute_input(BT, V, H, dtype, device): n_non_ignore = (target != -100).sum().unsqueeze(0) return (x, weight, target, lse, n_non_ignore) - +from helion.autotuner import PatternSearch def autotune_kernels(model_config_dataset): device = infer_device() torch_device = getattr(torch, device) @@ -475,7 +475,16 @@ def autotune_kernels(model_config_dataset): dtype=dtype, device=device, ) - config = fused_linear_cross_entropy_fwd.autotune(args) + bound = fused_linear_cross_entropy_fwd.bind(args) + tuner = PatternSearch( + bound, + args, + # Double the defaults to explore more candidates: + initial_population=100, # Default is 100. + copies=5, # Default is 5. + max_generations=10, # Default is 20. + ) + config = tuner.autotune() if dtype == torch.bfloat16: dtype_str = "bf16" elif dtype == torch.float32: @@ -492,7 +501,16 @@ def autotune_kernels(model_config_dataset): dtype=dtype, device=device, ) - config = fused_linear_cross_entropy_bwd.autotune(args) + bound = fused_linear_cross_entropy_bwd.bind(args) + tuner = PatternSearch( + bound, + args, + # Double the defaults to explore more candidates: + initial_population=100, # Default is 100. + copies=5, # Default is 5. + max_generations=10, # Default is 20. + ) + config = tuner.autotune() if dtype == torch.bfloat16: dtype_str = "bf16" elif dtype == torch.float32: @@ -509,7 +527,16 @@ def autotune_kernels(model_config_dataset): dtype=dtype, device=device, ) - config = _grad_logit_compute.autotune(args) + bound = _grad_logit_compute.bind(args) + tuner = PatternSearch( + bound, + args, + # Double the defaults to explore more candidates: + initial_population=100, # Default is 100. + copies=5, # Default is 5. + max_generations=10, # Default is 20. + ) + config = tuner.autotune() if dtype == torch.bfloat16: dtype_str = "bf16" elif dtype == torch.float32: @@ -600,8 +627,8 @@ def fwd_bwd_fn(input, target, fn): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--autotune", default=False) - parser.add_argument("--benchmark", default=True) + parser.add_argument("--autotune", action="store_true") + parser.add_argument("--benchmark", action="store_true") args = parser.parse_args() torch.manual_seed(0) @@ -610,8 +637,6 @@ def fwd_bwd_fn(input, target, fn): torch.backends.cudnn.benchmark = False torch.backends.cudnn.enabled = True - if args.benchmark: - check() model_config_dataset = { "llama": { @@ -631,3 +656,8 @@ def fwd_bwd_fn(input, target, fn): if args.autotune: print("autotuning all kernels...") autotune_kernels(model_config_dataset=model_config_dataset) + + if args.benchmark: + print("test correctness and benchmark all implementations") + check() + From 06bde208e5a70d8bf60786423f2ad2c3a81fbdc2 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat, 8 Nov 2025 20:06:42 +0000 Subject: [PATCH 26/27] Add h100 autotune configs Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> --- ...ear_cross_entropy_fwd_h100_llama_bf16.json | 45 +++++ ...ear_cross_entropy_fwd_h100_llama_fp32.json | 45 +++++ ...y_grad_logits_compute_h100_llama_fp32.json | 50 +++++ .../ops/helion/fused_linear_cross_entropy.py | 177 ++++++++++-------- 4 files changed, 236 insertions(+), 81 deletions(-) create mode 100644 src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_fwd_h100_llama_bf16.json create mode 100644 src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_fwd_h100_llama_fp32.json create mode 100644 src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_grad_logits_compute_h100_llama_fp32.json diff --git a/src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_fwd_h100_llama_bf16.json b/src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_fwd_h100_llama_bf16.json new file mode 100644 index 000000000..196aa1ab8 --- /dev/null +++ b/src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_fwd_h100_llama_bf16.json @@ -0,0 +1,45 @@ +{ + "block_sizes": [ + 64, + 64, + 256 + ], + "range_unroll_factors": [ + 0, + 1, + 1 + ], + "range_num_stages": [ + 0, + 3, + 4 + ], + "range_multi_buffers": [ + null, + false, + null + ], + "range_flattens": [ + null, + true, + true + ], + "load_eviction_policies": [ + "last", + "last", + "", + "" + ], + "num_warps": 4, + "num_stages": 8, + "indexing": [ + "tensor_descriptor", + "pointer", + "tensor_descriptor", + "tensor_descriptor", + "pointer", + "pointer" + ], + "pid_type": "flat", + "range_warp_specializes": [] +} \ No newline at end of file diff --git a/src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_fwd_h100_llama_fp32.json b/src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_fwd_h100_llama_fp32.json new file mode 100644 index 000000000..d19ea15e3 --- /dev/null +++ b/src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_fwd_h100_llama_fp32.json @@ -0,0 +1,45 @@ +{ + "block_sizes": [ + 64, + 32, + 256 + ], + "range_unroll_factors": [ + 0, + 1, + 1 + ], + "range_num_stages": [ + 0, + 3, + 4 + ], + "range_multi_buffers": [ + null, + true, + null + ], + "range_flattens": [ + null, + null, + true + ], + "load_eviction_policies": [ + "last", + "last", + "", + "" + ], + "num_warps": 4, + "num_stages": 6, + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "pointer", + "tensor_descriptor", + "pointer", + "tensor_descriptor" + ], + "pid_type": "flat", + "range_warp_specializes": [] +} \ No newline at end of file diff --git a/src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_grad_logits_compute_h100_llama_fp32.json b/src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_grad_logits_compute_h100_llama_fp32.json new file mode 100644 index 000000000..afb5b1dc1 --- /dev/null +++ b/src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_grad_logits_compute_h100_llama_fp32.json @@ -0,0 +1,50 @@ +{ + "block_sizes": [ + 64, + 32, + 256 + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "l2_groupings": [ + 32 + ], + "range_unroll_factors": [ + 0, + 1 + ], + "range_num_stages": [ + 4, + 2 + ], + "range_multi_buffers": [ + true, + null + ], + "range_flattens": [ + true, + true + ], + "load_eviction_policies": [ + "last", + "last", + "first", + "first" + ], + "num_warps": 8, + "num_stages": 1, + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + "pointer", + "tensor_descriptor" + ], + "pid_type": "persistent_interleaved", + "range_warp_specializes": [] +} \ No newline at end of file diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index fb7c3588b..5e1cae758 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -1,4 +1,5 @@ import argparse +from pathlib import Path import helion import helion.language as hl @@ -8,24 +9,13 @@ from liger_kernel.utils import infer_device -# Best config for default llama3Config(hidden_size=4096, vocab_size=32000) with 4096 bs*seqlen input -h100_fwd_config = helion.Config( - block_sizes=[64, 32, 512], - indexing=["pointer", "pointer", "tensor_descriptor", "pointer", "tensor_descriptor", "tensor_descriptor"], - load_eviction_policies=["", "last", "last", "last"], - num_stages=8, - num_warps=16, - pid_type="flat", - range_flattens=[None, False, None], - range_multi_buffers=[None, True, False], - range_num_stages=[0, 3, 3], - range_unroll_factors=[0, 0, 1], - range_warp_specializes=[], -) +CONFIG_PATH_STR = str(Path(__file__).parent.joinpath("configs", "fused_linear_cross_entropy")) -# @helion.kernel(config=h100_fwd_config, static_shapes=True) -@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) +# Best config for default llama3Config(hidden_size=4096, vocab_size=32000) with 4096 bs*seqlen input +h100_fwd_config = helion.Config.load(CONFIG_PATH_STR + "_fwd_h100_llama_fp32.json") +@helion.kernel(config=h100_fwd_config, static_shapes=True) +# @helion.kernel(autotune_effort="none", autotune_compile_timeout=20, ignore_warnings=[helion.exc.TensorOperationInWrapper]) def fused_linear_cross_entropy_fwd( x: torch.Tensor, weight: torch.Tensor, @@ -138,7 +128,7 @@ def fused_linear_cross_entropy_fwd( # @helion.kernel(config=h100_bwd_config, static_shapes=True) -@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) +@helion.kernel(autotune_effort="none", autotune_compile_timeout=20, ignore_warnings=[helion.exc.TensorOperationInWrapper]) def fused_linear_cross_entropy_bwd( x: torch.Tensor, weight: torch.Tensor, @@ -155,6 +145,7 @@ def fused_linear_cross_entropy_bwd( grad_x = torch.zeros_like(x, dtype=torch.float32) grad_w = torch.zeros_like(weight, dtype=torch.float32) n_non_ignore = (target != ignore_index).sum().unsqueeze(0) + assert n_non_ignore != 0, "All targets are ignored." num_block_bt = (BT + block_size_bt - 1) // block_size_bt num_block_h = (H + block_size_h - 1) // block_size_h @@ -186,15 +177,15 @@ def fused_linear_cross_entropy_bwd( grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :]) # handle ignore index - grad_logits_tile = grad_logits_tile * (target_indices != ignore_index) + # grad_logits_tile = grad_logits_tile * (target_indices != ignore_index) if reduction == "mean": grad_logits_tile /= n_non_ignore_value for tile_h in hl.tile(H, block_size=block_size_h): # grad_x = grad_logits @ weight - rhs_tile = weight[tile_v, tile_h] - partial_grad_x = hl.dot(grad_logits_tile, rhs_tile, out_dtype=torch.float32) + rhs_tile_1 = weight[tile_v, tile_h] + partial_grad_x = hl.dot(grad_logits_tile, rhs_tile_1, out_dtype=torch.float32) while hl.atomic_cas(grad_x_lock, [tile_bt.id, tile_h.id], 0, 1, sem="acquire") == 1: pass grad_x[tile_bt, tile_h] += partial_grad_x @@ -203,8 +194,8 @@ def fused_linear_cross_entropy_bwd( # for tile_h in hl.tile(H, block_size=block_size_h): # grad_w = grad_logits.T[tile_v, tile_bt] @ x[tile_bt, tile_h] - rhs_tile = x[tile_bt, tile_h] - partial_grad_w = hl.dot(grad_logits_tile.T, rhs_tile, out_dtype=torch.float32) + rhs_tile_2 = x[tile_bt, tile_h] + partial_grad_w = hl.dot(grad_logits_tile.T, rhs_tile_2, out_dtype=torch.float32) while hl.atomic_cas(grad_w_lock, [tile_v.id, tile_h.id], 0, 1, sem="acquire") == 1: pass grad_w[tile_v, tile_h] += partial_grad_w @@ -213,8 +204,9 @@ def fused_linear_cross_entropy_bwd( return grad_x.to(x.dtype), grad_w.to(x.dtype) - -@helion.kernel(autotune_effort="none", ignore_warnings=[helion.exc.TensorOperationInWrapper]) +h100_grad_logit_compute_config = helion.Config.load(CONFIG_PATH_STR + "_grad_logits_compute_h100_llama_fp32.json") +@helion.kernel(config=h100_grad_logit_compute_config, static_shapes=True) +# @helion.kernel(autotune_effort="none", autotune_compile_timeout=20, ignore_warnings=[helion.exc.TensorOperationInWrapper]) def _grad_logit_compute( x: torch.Tensor, weight: torch.Tensor, @@ -255,13 +247,14 @@ def _grad_logit_compute( grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :]) # handle ignore index - grad_logits_tile = grad_logits_tile * (target_indices != ignore_index) + # grad_logits_tile = grad_logits_tile * (target_indices != ignore_index) if reduction == "mean": grad_logits_tile /= n_non_ignore_value grad_logits[tile_bt, tile_v] = grad_logits_tile - return grad_logits + + return grad_logits.to(x.dtype) def fused_linear_cross_entropy_bwd_chunk( @@ -447,7 +440,7 @@ def generate_flce_bwd_input(BT, V, H, dtype, device): x = torch.randn(BT, H, device=device, dtype=dtype) weight = torch.randn(V, H, device=device, dtype=dtype) target = torch.randint(0, V, (BT,), device=device) - lse = torch.randn(BT, device=device, dtype=torch.float32) + lse = torch.randn(BT, device=device, dtype=torch.float32) + 1.0 return (x, weight, target, lse) @@ -455,19 +448,40 @@ def generate_grad_logits_compute_input(BT, V, H, dtype, device): x = torch.randn(BT, H, device=device, dtype=dtype) weight = torch.randn(V, H, device=device, dtype=dtype) target = torch.randint(0, V, (BT,), device=device) - lse = torch.randn(BT, device=device, dtype=torch.float32) + lse = torch.randn(BT, device=device, dtype=torch.float32) + 1.0 n_non_ignore = (target != -100).sum().unsqueeze(0) return (x, weight, target, lse, n_non_ignore) +from pathlib import Path + from helion.autotuner import PatternSearch def autotune_kernels(model_config_dataset): device = infer_device() torch_device = getattr(torch, device) gpu_name = torch_device.get_device_name(torch_device.current_device()) - + if "h100" in gpu_name.lower(): + gpu_name = "h100" + elif "a100" in gpu_name.lower(): + gpu_name = "a100" + elif "b200" in gpu_name.lower(): + gpu_name = "b200" + + # bf16 has nan issue + # dtypes = [torch.bfloat16, torch.float32] + dtypes = [torch.float32] + for model_name, model_config in model_config_dataset.items(): - for dtype in [torch.bfloat16, torch.float32]: + + for dtype in dtypes: BT = 4096 + if dtype == torch.bfloat16: + dtype_str = "bf16" + elif dtype == torch.float32: + dtype_str = "fp32" + file = Path(f"{CONFIG_PATH_STR}_fwd_{gpu_name}_{model_name}_{dtype_str}.json") + if file.is_file(): + print(f"File exists at {str(file)} . Skip autotuning") + continue args = generate_flce_fwd_input( BT, model_config["hidden_size"], @@ -479,47 +493,54 @@ def autotune_kernels(model_config_dataset): tuner = PatternSearch( bound, args, - # Double the defaults to explore more candidates: - initial_population=100, # Default is 100. + initial_population=50, # Default is 100. copies=5, # Default is 5. - max_generations=10, # Default is 20. + max_generations=15, # Default is 20. ) config = tuner.autotune() - if dtype == torch.bfloat16: - dtype_str = "bf16" - elif dtype == torch.float32: - dtype_str = "fp32" - config.save(f"configs/fused_linear_cross_entropy_fwd_{gpu_name}_{model_name}_{dtype_str}.json") + config.save(f"{CONFIG_PATH_STR}_fwd_{gpu_name}_{model_name}_{dtype_str}.json") + # nan if shapes are not divisible (out of bound values?) + # for model_name, model_config in model_config_dataset.items(): + # for dtype in dtypes: + # BT = 4096 + # if dtype == torch.bfloat16: + # dtype_str = "bf16" + # elif dtype == torch.float32: + # dtype_str = "fp32" + # file = Path(f"{CONFIG_PATH_STR}_bwd_{gpu_name}_{model_name}_{dtype_str}.json") + # if file.is_file(): + # print(f"File exists at {str(file)}. Skip autotuning") + # continue + # args = generate_flce_bwd_input( + # BT, + # model_config["hidden_size"], + # model_config["vocab_size"], + # dtype=dtype, + # device=device, + # ) + # bound = fused_linear_cross_entropy_bwd.bind(args) + # tuner = PatternSearch( + # bound, + # args, + # initial_population=50, # Default is 100. + # copies=5, # Default is 5. + # max_generations=15, # Default is 20. + # ) + # config = tuner.autotune() + + # config.save(f"{CONFIG_PATH_STR}_bwd_{gpu_name}_{model_name}_{dtype_str}.json") for model_name, model_config in model_config_dataset.items(): - for dtype in [torch.bfloat16, torch.float32]: + for dtype in dtypes: BT = 4096 - args = generate_flce_bwd_input( - BT, - model_config["hidden_size"], - model_config["vocab_size"], - dtype=dtype, - device=device, - ) - bound = fused_linear_cross_entropy_bwd.bind(args) - tuner = PatternSearch( - bound, - args, - # Double the defaults to explore more candidates: - initial_population=100, # Default is 100. - copies=5, # Default is 5. - max_generations=10, # Default is 20. - ) - config = tuner.autotune() if dtype == torch.bfloat16: dtype_str = "bf16" elif dtype == torch.float32: dtype_str = "fp32" - config.save(f"configs/fused_linear_cross_entropy_bwd_{gpu_name}_{model_name}_{dtype_str}.json") - - for model_name, model_config in model_config_dataset.items(): - for dtype in [torch.bfloat16, torch.float32]: - BT = 4096 + file = Path(f"{CONFIG_PATH_STR}_grad_logits_compute_{gpu_name}_{model_name}_{dtype_str}.json") + if file.is_file(): + print(f"File exists at {str(file)}. Skip autotuning") + continue args = generate_grad_logits_compute_input( BT, model_config["hidden_size"], @@ -531,24 +552,19 @@ def autotune_kernels(model_config_dataset): tuner = PatternSearch( bound, args, - # Double the defaults to explore more candidates: - initial_population=100, # Default is 100. + initial_population=50, # Default is 100. copies=5, # Default is 5. - max_generations=10, # Default is 20. + max_generations=15, # Default is 20. ) config = tuner.autotune() - if dtype == torch.bfloat16: - dtype_str = "bf16" - elif dtype == torch.float32: - dtype_str = "fp32" - config.save(f"configs/_grad_logit_compute_{gpu_name}_{model_name}_{dtype_str}.json") + config.save(f"{CONFIG_PATH_STR}_grad_logits_compute_{gpu_name}_{model_name}_{dtype_str}.json") def check(): device = infer_device() batch_size = 2 - seq_len = 4096 + seq_len = 2048 hidden_size = 4096 vocab_size = 32000 @@ -586,7 +602,7 @@ def check(): def fwd_bwd_fn(input, target, fn): loss = fn(input, target) loss.backward() - return loss + return input.grad liger_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=liger_lm_head_ce) liger_chunk_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=liger_chunk_lm_head_ce) @@ -611,7 +627,7 @@ def fwd_bwd_fn(input, target, fn): if reduction != "none": run_example( { - "helion_fwd_bwd_cce": liger_lm_head_ce_fwd_bwd, + # "helion_fwd_bwd_cce": liger_lm_head_ce_fwd_bwd, # nan "helion_fwd_bwd_chunk": liger_chunk_lm_head_ce_fwd_bwd, }, { @@ -637,20 +653,19 @@ def fwd_bwd_fn(input, target, fn): torch.backends.cudnn.benchmark = False torch.backends.cudnn.enabled = True - model_config_dataset = { "llama": { "hidden_size": 4096, "vocab_size": 32000, }, - "gemma3": { - "hidden_size": 2305, - "vocab_size": 262208, - }, - "qwen3": { - "hidden_size": 4096, - "vocab_size": 151936, - }, + # "gemma3": { + # "hidden_size": 2305, + # "vocab_size": 262208, + # }, + # "qwen3": { + # "hidden_size": 4096, + # "vocab_size": 151936, + # }, } if args.autotune: From 4eb0a69084488e86e1157bbeb7a96bd2055538fd Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Tue, 11 Nov 2025 16:03:22 +0000 Subject: [PATCH 27/27] Fix reduction!="mean" and add benchmark script --- .../benchmark_fused_linear_cross_entropy.py | 25 +- ...nd_grad_logit_compute_h100_llama_fp32.json | 49 +++ .../ops/helion/fused_linear_cross_entropy.py | 333 +++++++++++++++--- .../helion/test_fused_linear_cross_entropy.py | 18 +- 4 files changed, 375 insertions(+), 50 deletions(-) create mode 100644 src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_nll_and_grad_logit_compute_h100_llama_fp32.json diff --git a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py index 4d36a66a6..261bd2a1d 100644 --- a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +++ b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py @@ -8,6 +8,7 @@ from utils import parse_benchmark_script_args from utils import run_benchmarks +from liger_kernel.ops.helion.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyHelion from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.utils import infer_device @@ -45,6 +46,20 @@ def forward(self, x, y): return self.ce_loss(self.lin.weight, x, y) +class LigerLMHeadCEHelion(torch.nn.Module): + def __init__( + self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100, bwd_impl="chunk", grad_in_forward=False + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) + self.ce_loss = LigerFusedLinearCrossEntropyHelion( + ignore_index=ignore_index, reduction="mean", bwd_impl=bwd_impl, grad_in_forward=grad_in_forward + ) + + def forward(self, x, y): + return self.ce_loss(x, self.lin.weight, y) + + ############################################################################# # Test the memory consumption of the linear fused cross entropy loss ############################################################################# @@ -64,6 +79,10 @@ def bench_memory_fused_linear_cross_entropy( lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device) elif provider == "liger-fp32-accum": lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device) + elif provider == "liger-helion": + lm_head_ce = LigerLMHeadCEHelion(H=H, V=V, dtype=dtype, bwd_impl="chunk", grad_in_forward=False).to(device) + elif provider == "liger-helion-grad-in-fwd": + lm_head_ce = LigerLMHeadCEHelion(H=H, V=V, dtype=dtype, bwd_impl="chunk", grad_in_forward=True).to(device) else: lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device) @@ -106,6 +125,10 @@ def bench_speed_fused_linear_cross_entropy( lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device) elif provider == "liger-fp32-accum": lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device) + elif provider == "liger-helion": + lm_head_ce = LigerLMHeadCEHelion(H=H, V=V, dtype=dtype, bwd_impl="chunk", grad_in_forward=False).to(device) + elif provider == "liger-helion-grad-in-fwd": + lm_head_ce = LigerLMHeadCEHelion(H=H, V=V, dtype=dtype, bwd_impl="chunk", grad_in_forward=True).to(device) else: lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device) @@ -163,7 +186,7 @@ def full(): "x_name": "BT", "x_label": "B x T", "x_values": [2**i for i in range(12, 16)], - "kernel_providers": ["liger", "liger-fp32-accum", "huggingface"], + "kernel_providers": ["liger", "liger-fp32-accum", "huggingface", "liger-helion", "liger-helion-grad-in-fwd"], "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}], "overwrite": args.overwrite, } diff --git a/src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_nll_and_grad_logit_compute_h100_llama_fp32.json b/src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_nll_and_grad_logit_compute_h100_llama_fp32.json new file mode 100644 index 000000000..3fe8fc1a6 --- /dev/null +++ b/src/liger_kernel/ops/helion/configs/fused_linear_cross_entropy_nll_and_grad_logit_compute_h100_llama_fp32.json @@ -0,0 +1,49 @@ +{ + "block_sizes": [ + 64, + 32, + 512 + ], + "range_unroll_factors": [ + 3, + 0, + 0 + ], + "range_num_stages": [ + 4, + 0, + 3 + ], + "range_multi_buffers": [ + true, + true, + false + ], + "range_flattens": [ + true, + true, + false + ], + "load_eviction_policies": [ + "last", + "first", + "", + "last", + "", + "first" + ], + "num_warps": 8, + "num_stages": 7, + "indexing": [ + "pointer", + "tensor_descriptor", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer" + ], + "pid_type": "persistent_blocked", + "range_warp_specializes": [] +} \ No newline at end of file diff --git a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py index 5e1cae758..e964500d3 100644 --- a/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/helion/fused_linear_cross_entropy.py @@ -1,4 +1,5 @@ import argparse + from pathlib import Path import helion @@ -14,7 +15,9 @@ # Best config for default llama3Config(hidden_size=4096, vocab_size=32000) with 4096 bs*seqlen input h100_fwd_config = helion.Config.load(CONFIG_PATH_STR + "_fwd_h100_llama_fp32.json") -@helion.kernel(config=h100_fwd_config, static_shapes=True) + + +@helion.kernel(config=h100_fwd_config, static_shapes=True, ignore_warnings=[helion.exc.TensorOperationInWrapper]) # @helion.kernel(autotune_effort="none", autotune_compile_timeout=20, ignore_warnings=[helion.exc.TensorOperationInWrapper]) def fused_linear_cross_entropy_fwd( x: torch.Tensor, @@ -128,7 +131,9 @@ def fused_linear_cross_entropy_fwd( # @helion.kernel(config=h100_bwd_config, static_shapes=True) -@helion.kernel(autotune_effort="none", autotune_compile_timeout=20, ignore_warnings=[helion.exc.TensorOperationInWrapper]) +@helion.kernel( + autotune_effort="none", autotune_compile_timeout=20, ignore_warnings=[helion.exc.TensorOperationInWrapper] +) def fused_linear_cross_entropy_bwd( x: torch.Tensor, weight: torch.Tensor, @@ -204,8 +209,13 @@ def fused_linear_cross_entropy_bwd( return grad_x.to(x.dtype), grad_w.to(x.dtype) + h100_grad_logit_compute_config = helion.Config.load(CONFIG_PATH_STR + "_grad_logits_compute_h100_llama_fp32.json") -@helion.kernel(config=h100_grad_logit_compute_config, static_shapes=True) + + +@helion.kernel( + config=h100_grad_logit_compute_config, static_shapes=True, ignore_warnings=[helion.exc.TensorOperationInWrapper] +) # @helion.kernel(autotune_effort="none", autotune_compile_timeout=20, ignore_warnings=[helion.exc.TensorOperationInWrapper]) def _grad_logit_compute( x: torch.Tensor, @@ -253,7 +263,7 @@ def _grad_logit_compute( grad_logits_tile /= n_non_ignore_value grad_logits[tile_bt, tile_v] = grad_logits_tile - + return grad_logits.to(x.dtype) @@ -299,52 +309,239 @@ def fused_linear_cross_entropy_bwd_chunk( return grad_x.to(x.dtype), grad_w.to(x.dtype) -class LigerFusedLinearCrossEntropyHelionFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, _input, weight, target, ignore_index=-100, reduction="mean", bwd_impl="chunk"): - assert bwd_impl in ["chunk", "cce"] - loss, lse = fused_linear_cross_entropy_fwd( - _input, +h100_nll_and_grad_logit_compute_config = helion.Config.load( + CONFIG_PATH_STR + "_nll_and_grad_logit_compute_h100_llama_fp32.json" +) + + +@helion.kernel( + config=h100_nll_and_grad_logit_compute_config, + static_shapes=False, + ignore_warnings=[helion.exc.TensorOperationInWrapper], +) +# @helion.kernel(autotune_effort="none", autotune_compile_timeout=20, ignore_warnings=[helion.exc.TensorOperationInWrapper]) +def _nll_and_grad_logit_compute( + x: torch.Tensor, + weight: torch.Tensor, + target: torch.Tensor, + n_non_ignore: torch.Tensor, + ignore_index: int = -100, + reduction: str = "mean", +): + BT, H = x.size() + V = weight.size(0) + + block_size_bt = hl.register_block_size(BT) + block_size_h = hl.register_block_size(H) + block_size_v = hl.register_block_size(V) + grad_logits = torch.zeros((BT, V), dtype=torch.float32, device=x.device) + nll = torch.zeros(BT, dtype=torch.float32, device=x.device) + + for tile_bt in hl.tile(BT, block_size=block_size_bt): + m_i = hl.zeros([tile_bt], dtype=torch.float32) - float("inf") + d_i = hl.zeros([tile_bt], dtype=torch.float32) + nll_tile = hl.zeros([tile_bt], dtype=torch.float32) + if reduction == "mean": + n_non_ignore_value = hl.load(n_non_ignore, [0]) + + # target_indices = target[tile_bt][:, None] # ERROR: it introduces a new size, which is not broadcastable + target_indices = target[tile_bt].unsqueeze(1) # [tile_bt, 1] + + # statistics pass + for tile_v in hl.tile(V, block_size=block_size_v): + # logits computation + acc = hl.zeros([tile_bt, tile_v], dtype=torch.float32) + for tile_h in hl.tile(H, block_size=block_size_h): + x_tile = x[tile_bt, tile_h] + weight_tile = weight[tile_v, tile_h] + acc = hl.dot(x_tile, weight_tile.T, acc=acc, out_dtype=torch.float32) + + # online softmax statistics + m_ij = torch.maximum(m_i, torch.amax(acc, dim=-1)) + d_i = d_i * torch.exp(m_i - m_ij) + torch.exp(acc - m_ij[:, None]).sum(dim=-1) + m_i = m_ij + + # offset = tile_v.index[None, :] # ERROR: it introduces a new size, which is not broadcastable + offset = tile_v.index.unsqueeze(0) # [1, tile_v] + mask = target_indices == offset # [tile_bt, tile_v] + nll_tile += torch.sum(-acc * mask, dim=-1) # [tile_bt] + + # loss computation: -logsoftmax(x_y) = -log(exp(x_y) / sum(exp(x_i))) = -x_y + log(sum(exp(x_i))) + lse_tile = m_i + torch.log(d_i) + nll_tile = nll_tile + lse_tile + + # handle ignore index + nll_tile = nll_tile * (target_indices.ravel() != ignore_index) + + if reduction == "mean": + nll_tile /= n_non_ignore_value + + nll[tile_bt] = nll_tile + + # gradients pass + for tile_v in hl.tile(V, block_size=block_size_v): + # Restore logits + acc2 = hl.zeros([tile_bt, tile_v], dtype=torch.float32) + for tile_h in hl.tile(H, block_size=block_size_h): + x_tile = x[tile_bt, tile_h] + weight_tile = weight[tile_v, tile_h] + acc2 = hl.dot(x_tile, weight_tile.T, acc=acc2, out_dtype=torch.float32) + + # logsoftmax(x_i) = softmax(x_i) - 1, for i == target + # = softmax(x_i), otherwise + # softmax(x_i) = exp(x_i) / sum(exp(x_i)) + # = exp(x_i) / log(exp(sum(x_i))) + # = exp(x_i) / lse = exp(x_i - lse) + grad_logits_tile = torch.exp(acc2 - lse_tile[:, None]) + offset = tile_v.index.unsqueeze(0) # [1, tile_v] + mask = target_indices == offset # [tile_bt, tile_v] + # handle i == target + grad_logits_tile = grad_logits_tile - mask.float() + # handle out of bound values in grad_logits_tile + grad_logits_tile = grad_logits_tile * ((tile_bt.index < BT)[:, None] & (tile_v.index < V)[None, :]) + + # handle ignore index + grad_logits_tile = grad_logits_tile * (target_indices != ignore_index) + + if reduction == "mean": + grad_logits_tile /= n_non_ignore_value + + grad_logits[tile_bt, tile_v] = grad_logits_tile + + return nll, grad_logits.to(x.dtype) + + +def fused_linear_cross_entropy_fwd_bwd_chunk( + x: torch.Tensor, + weight: torch.Tensor, + target: torch.Tensor, + ignore_index: int = -100, + reduction: str = "mean", +): + """ + Performs matrix multiplication followed by cross entropy loss, and gradients are all computed + in forward pass. + Args: + x: input tensor of shape [BT, H] + weight: weight tensor of shape [V, H] + target: target tensor of shape [BT,] + ignore_index: index to ignore in the target + reduction: reduction to apply to the loss + Returns: + loss: loss tensor of shape [1] if reduction is "mean" or "sum", [BT] otherwise + """ + BT, H = x.size() + V = weight.size(0) + # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + num_chunks = (V + H - 1) // H + chunk_size = (BT + num_chunks - 1) // num_chunks + grad_x = torch.zeros_like(x, dtype=torch.float32) + grad_w = torch.zeros_like(weight, dtype=torch.float32) + if reduction == "mean": + n_non_ignore = (target != ignore_index).sum().unsqueeze(0) + else: + n_non_ignore = torch.ones(1, device=x.device, dtype=torch.int) + + nll = torch.zeros(BT, device=x.device, dtype=torch.float32) + + x_chunks = torch.chunk(x, chunks=num_chunks, dim=0) + target_chunks = torch.chunk(target, chunks=num_chunks, dim=0) + nll_chunks = torch.chunk(target, chunks=num_chunks, dim=0) + + for chunk_id, (x_chunk, target_chunk, nll_chunk) in enumerate(zip(x_chunks, target_chunks, nll_chunks)): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + + nll_chunk, grad_logits_chunk = _nll_and_grad_logit_compute( + x_chunk, weight, - target, - ignore_index, + target_chunk, + n_non_ignore, reduction, ) + + grad_x[start_idx:end_idx] = grad_logits_chunk @ weight + grad_w += torch.mm(grad_logits_chunk.T, x_chunk).float() + + nll[start_idx:end_idx] = nll_chunk + + if reduction != "none": + loss = nll.sum() + else: + loss = nll + + print(f"{reduction=}") + return loss, grad_x.to(x.dtype), grad_w.to(x.dtype) + + +class LigerFusedLinearCrossEntropyHelionFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, _input, weight, target, ignore_index=-100, reduction="mean", bwd_impl="chunk", grad_in_forward=False + ): + assert bwd_impl in ["chunk", "cce"] + assert grad_in_forward in [True, False] + if grad_in_forward: + loss, grad_x, grad_w = fused_linear_cross_entropy_fwd_bwd_chunk( + _input, + weight, + target, + ignore_index, + reduction, + ) + ctx.save_for_backward(grad_x, grad_w) + else: + loss, lse = fused_linear_cross_entropy_fwd( + _input, + weight, + target, + ignore_index, + reduction, + ) + ctx.save_for_backward(_input, lse, weight, target) ctx.ignore_index = ignore_index ctx.reduction = reduction ctx.bwd_impl = bwd_impl - ctx.save_for_backward(_input, lse, weight, target) + ctx.grad_in_forward = grad_in_forward return loss @staticmethod def backward(ctx, grad_output): assert grad_output.ndim == 0, "token_scaling is not supported. grad_output must be a scalar" - _input, lse, weight, target = ctx.saved_tensors - if ctx.bwd_impl == "cce": - bwd_fn = fused_linear_cross_entropy_bwd - elif ctx.bwd_impl == "chunk": - bwd_fn = fused_linear_cross_entropy_bwd_chunk - grad_input, grad_weight = bwd_fn( - _input, - weight, - target, - lse, - ctx.ignore_index, - ctx.reduction, - ) - return grad_input * grad_output, grad_weight * grad_output, None, None, None, None + if ctx.grad_in_forward: + grad_input, grad_weight = ctx.saved_tensors + else: + _input, lse, weight, target = ctx.saved_tensors + if ctx.bwd_impl == "cce": + bwd_fn = fused_linear_cross_entropy_bwd + elif ctx.bwd_impl == "chunk": + bwd_fn = fused_linear_cross_entropy_bwd_chunk + grad_input, grad_weight = bwd_fn( + _input, + weight, + target, + lse, + ctx.ignore_index, + ctx.reduction, + ) + return grad_input * grad_output, grad_weight * grad_output, None, None, None, None, None class LigerFusedLinearCrossEntropyHelion(torch.nn.Module): - def __init__(self, ignore_index=-100, reduction="mean", bwd_impl="chunk"): + def __init__(self, ignore_index=-100, reduction="mean", bwd_impl="chunk", grad_in_forward=False): super().__init__() self.ignore_index = ignore_index self.reduction = reduction self.bwd_impl = bwd_impl + self.grad_in_forward = grad_in_forward def forward(self, _input, weight, target): + assert _input.device == weight.device, f"{_input.device=}, {weight.device=}" + assert _input.device == target.device, f"{_input.device=}, {target.device=}" return LigerFusedLinearCrossEntropyHelionFunction.apply( - _input, weight, target, self.ignore_index, self.reduction, self.bwd_impl + _input, weight, target, self.ignore_index, self.reduction, self.bwd_impl, self.grad_in_forward ) @@ -376,11 +573,12 @@ def __init__( ignore_index: int = -100, reduction: str = "mean", bwd_impl: str = "cce", + grad_in_forward: bool = False, ): super().__init__() self.lm_head = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) self.flce = LigerFusedLinearCrossEntropyHelion( - ignore_index=ignore_index, reduction=reduction, bwd_impl=bwd_impl + ignore_index=ignore_index, reduction=reduction, bwd_impl=bwd_impl, grad_in_forward=grad_in_forward ) def forward(self, x, target): @@ -440,7 +638,7 @@ def generate_flce_bwd_input(BT, V, H, dtype, device): x = torch.randn(BT, H, device=device, dtype=dtype) weight = torch.randn(V, H, device=device, dtype=dtype) target = torch.randint(0, V, (BT,), device=device) - lse = torch.randn(BT, device=device, dtype=torch.float32) + 1.0 + lse = torch.logsumexp(x @ weight.T, dim=-1) return (x, weight, target, lse) @@ -448,13 +646,24 @@ def generate_grad_logits_compute_input(BT, V, H, dtype, device): x = torch.randn(BT, H, device=device, dtype=dtype) weight = torch.randn(V, H, device=device, dtype=dtype) target = torch.randint(0, V, (BT,), device=device) - lse = torch.randn(BT, device=device, dtype=torch.float32) + 1.0 + lse = torch.logsumexp(x @ weight.T, dim=-1) n_non_ignore = (target != -100).sum().unsqueeze(0) return (x, weight, target, lse, n_non_ignore) + +def generate_nll_and_grad_logit_compute_input(BT, V, H, dtype, device): + x = torch.randn(BT, H, device=device, dtype=dtype) + weight = torch.randn(V, H, device=device, dtype=dtype) + target = torch.randint(0, V, (BT,), device=device) + n_non_ignore = (target != -100).sum().unsqueeze(0) + return (x, weight, target, n_non_ignore) + + from pathlib import Path from helion.autotuner import PatternSearch + + def autotune_kernels(model_config_dataset): device = infer_device() torch_device = getattr(torch, device) @@ -469,9 +678,8 @@ def autotune_kernels(model_config_dataset): # bf16 has nan issue # dtypes = [torch.bfloat16, torch.float32] dtypes = [torch.float32] - + for model_name, model_config in model_config_dataset.items(): - for dtype in dtypes: BT = 4096 if dtype == torch.bfloat16: @@ -494,8 +702,8 @@ def autotune_kernels(model_config_dataset): bound, args, initial_population=50, # Default is 100. - copies=5, # Default is 5. - max_generations=15, # Default is 20. + copies=5, # Default is 5. + max_generations=15, # Default is 20. ) config = tuner.autotune() config.save(f"{CONFIG_PATH_STR}_fwd_{gpu_name}_{model_name}_{dtype_str}.json") @@ -527,7 +735,7 @@ def autotune_kernels(model_config_dataset): # max_generations=15, # Default is 20. # ) # config = tuner.autotune() - + # config.save(f"{CONFIG_PATH_STR}_bwd_{gpu_name}_{model_name}_{dtype_str}.json") for model_name, model_config in model_config_dataset.items(): @@ -547,28 +755,56 @@ def autotune_kernels(model_config_dataset): model_config["vocab_size"], dtype=dtype, device=device, - ) + ) # args = (x, weight, target, lse, n_non_ignore) bound = _grad_logit_compute.bind(args) tuner = PatternSearch( bound, args, initial_population=50, # Default is 100. - copies=5, # Default is 5. - max_generations=15, # Default is 20. + copies=5, # Default is 5. + max_generations=15, # Default is 20. ) config = tuner.autotune() config.save(f"{CONFIG_PATH_STR}_grad_logits_compute_{gpu_name}_{model_name}_{dtype_str}.json") + for model_name, model_config in model_config_dataset.items(): + for dtype in dtypes: + BT = 4096 + if dtype == torch.bfloat16: + dtype_str = "bf16" + elif dtype == torch.float32: + dtype_str = "fp32" + file = Path(f"{CONFIG_PATH_STR}_nll_and_grad_logit_compute_{gpu_name}_{model_name}_{dtype_str}.json") + if file.is_file(): + print(f"File exists at {str(file)}. Skip autotuning") + continue + args = generate_nll_and_grad_logit_compute_input( + BT, + model_config["hidden_size"], + model_config["vocab_size"], + dtype=dtype, + device=device, + ) # args = (x, weight, target, nll, n_non_ignore) + bound = _nll_and_grad_logit_compute.bind(args) + tuner = PatternSearch( + bound, + args, + initial_population=50, # Default is 100. + copies=5, # Default is 5. + max_generations=15, # Default is 20. + ) + config = tuner.autotune() + config.save(f"{CONFIG_PATH_STR}_nll_and_grad_logit_compute_{gpu_name}_{model_name}_{dtype_str}.json") + def check(): device = infer_device() batch_size = 2 - seq_len = 2048 + seq_len = 4096 hidden_size = 4096 vocab_size = 32000 - print(f"BT={batch_size * seq_len}, H={hidden_size}, V={vocab_size}") dtype = torch.float32 @@ -588,6 +824,9 @@ def check(): liger_chunk_lm_head_ce = LigerLMHeadCE( hidden_size, vocab_size, dtype=dtype, reduction=reduction, bwd_impl="chunk" ).to(device=device) + liger_lm_head_ce_grad_in_fwd = LigerLMHeadCE( + hidden_size, vocab_size, dtype=dtype, reduction=reduction, bwd_impl="chunk", grad_in_forward=True + ).to(device=device) cce_lm_head_ce = CutLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to(device=device) triton_liger_lm_head_ce = TritonLigerLMHeadCE(hidden_size, vocab_size, dtype=dtype, reduction=reduction).to( device=device @@ -596,6 +835,7 @@ def check(): ref_lm_head_ce.lm_head.weight.data = weight.data liger_lm_head_ce.lm_head.weight.data = weight.data liger_chunk_lm_head_ce.lm_head.weight.data = weight.data + liger_lm_head_ce_grad_in_fwd.lm_head.weight.data = weight.data cce_lm_head_ce.lm_head.weight.data = weight.data triton_liger_lm_head_ce.lm_head.weight.data = weight.data @@ -606,6 +846,7 @@ def fwd_bwd_fn(input, target, fn): liger_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=liger_lm_head_ce) liger_chunk_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=liger_chunk_lm_head_ce) + liger_lm_head_ce_grad_in_fwd_full = partial(fwd_bwd_fn, fn=liger_lm_head_ce_grad_in_fwd) ref_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=ref_lm_head_ce) cce_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=cce_lm_head_ce) triton_liger_lm_head_ce_fwd_bwd = partial(fwd_bwd_fn, fn=triton_liger_lm_head_ce) @@ -613,7 +854,11 @@ def fwd_bwd_fn(input, target, fn): # Test and Benchmark run_example( - liger_lm_head_ce, + { + # "helion_fwd_bwd_cce": liger_lm_head_ce_fwd_bwd, # nan + "helion_fwd": liger_lm_head_ce, + "helion_grad_in_fwd": liger_lm_head_ce_grad_in_fwd, + }, { "torch_fwd": ref_lm_head_ce, "cce_fwd": cce_lm_head_ce, @@ -629,6 +874,7 @@ def fwd_bwd_fn(input, target, fn): { # "helion_fwd_bwd_cce": liger_lm_head_ce_fwd_bwd, # nan "helion_fwd_bwd_chunk": liger_chunk_lm_head_ce_fwd_bwd, + "helion_grad_in_fwd": liger_lm_head_ce_grad_in_fwd_full, # There is a constant overhead after fwd & bwd pass }, { "torch_fwd_bwd": ref_lm_head_ce_fwd_bwd, @@ -675,4 +921,3 @@ def fwd_bwd_fn(input, target, fn): if args.benchmark: print("test correctness and benchmark all implementations") check() - diff --git a/test/transformers/helion/test_fused_linear_cross_entropy.py b/test/transformers/helion/test_fused_linear_cross_entropy.py index a0a935e0d..27f377459 100644 --- a/test/transformers/helion/test_fused_linear_cross_entropy.py +++ b/test/transformers/helion/test_fused_linear_cross_entropy.py @@ -11,6 +11,7 @@ device = infer_device() + def supports_bfloat16(): if device == "cuda": return torch.cuda.get_device_capability() >= (8, 0) # Ampere and newer @@ -19,6 +20,7 @@ def supports_bfloat16(): else: return False + def set_seed(seed=42): """ Fix all random seeds we use for reproducibility. @@ -45,9 +47,11 @@ def set_seed(seed=42): # Python hash seed os.environ["PYTHONHASHSEED"] = str(seed) - + + set_seed(42) + class TorchLMHeadCE(torch.nn.Module): def __init__( self, @@ -77,11 +81,14 @@ def __init__( ): super().__init__() self.lm_head = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) - self.flce = LigerFusedLinearCrossEntropyHelion(ignore_index=ignore_index, reduction=reduction) + self.flce = LigerFusedLinearCrossEntropyHelion( + ignore_index=ignore_index, reduction=reduction, grad_in_forward=True + ) def forward(self, x, target): return self.flce(x, self.lm_head.weight, target) + @pytest.mark.parametrize( "B, T, H, V", [ @@ -136,7 +143,8 @@ def test_fused_linear_cross_entropy_correctness(B, T, H, V, reduction, dtype, at assert liger_input.grad.isinf().sum() == 0 torch.testing.assert_close(liger_input.grad, ref_input.grad, rtol=rtol, atol=atol) torch.testing.assert_close( - liger_lm_head_ce.lm_head.weight.grad, ref_lm_head_ce.lm_head.weight.grad, rtol=rtol, atol=atol, + liger_lm_head_ce.lm_head.weight.grad, + ref_lm_head_ce.lm_head.weight.grad, + rtol=rtol, + atol=atol, ) - -