Commit 452205e

[Example] Add fused_linear_cross_entropy example and unit test

1 parent 41fe6e9 commit 452205e

File tree

3 files changed: +373 -0 lines changed
examples/fused_linear_cross_entropy.py

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
"""Fused linear cross entropy implementation for Helion.

This implementation uses Liger's chunking strategy to reduce memory usage.
"""

from __future__ import annotations

import os

import torch

import helion
from helion._testing import run_example
import helion.language as hl

# TritonBench configuration
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
    # Low memory configuration
    TRITONBENCH_ARGS = {"hidden_size": 2048, "vocab_size": 32000}

# Maximum chunk size (similar to Liger's MAX_FUSED_SIZE)
MAX_FUSED_SIZE = 65536 // 2


@helion.kernel(static_shapes=True)
def cross_entropy_kernel(
    logits_chunk: torch.Tensor,  # [chunk_size, vocab_size]
    target_chunk: torch.Tensor,  # [chunk_size]
    loss_chunk: torch.Tensor,  # [chunk_size]
    chunk_size: int,
    vocab_size: int,
    n_total_samples: int,  # Total number of samples for mean reduction
) -> None:
    # Grid over samples - each program handles one sample
    for program_id in hl.grid(chunk_size):
        target_idx = target_chunk[program_id].unsqueeze(0)

        # Online softmax: first pass - find max and sum
        m = hl.full([], float("-inf"))  # max value
        d = hl.full([], 0.0)  # sum of exp

        # Store original value at target
        ori_logit_y = logits_chunk[program_id, target_idx]

        # Process in blocks like Liger
        for vocab_tile in hl.tile(vocab_size):
            # Create block offsets (like tl.arange in Triton)
            block_offsets = vocab_tile.index

            # Masked load of block
            mask = block_offsets < vocab_size
            logits_block = torch.where(
                mask, logits_chunk[program_id, block_offsets], float("-inf")
            )

            # Find block max
            block_max = torch.max(logits_block)

            # Online softmax update
            m_new = torch.maximum(m, block_max)
            d = d * torch.exp(m - m_new) + torch.sum(torch.exp(logits_block - m_new))
            m = m_new

        # Compute log-sum-exp
        lse = m + torch.log(d)
        loss = lse - ori_logit_y
        # Apply mean reduction inside the kernel
        loss_chunk[program_id] = (loss / n_total_samples).squeeze(0)

        # Second pass: compute gradients with block processing
        for vocab_tile in hl.tile(vocab_size):
            block_offsets = vocab_tile.index
            mask = block_offsets < vocab_size

            # Load block
            logits_block = torch.where(
                mask, logits_chunk[program_id, block_offsets], 0.0
            )

            # Compute softmax
            softmax_block = torch.exp(logits_block - m) / d

            # Special handling for target
            is_target_block = block_offsets == target_idx
            grad_block = torch.where(
                is_target_block, softmax_block - 1.0, softmax_block
            )

            # Apply mean reduction to gradients
            grad_block = grad_block / n_total_samples

            # Masked store using torch.where pattern
            # First, load existing values for positions that will be masked out
            existing_values = logits_chunk[program_id, block_offsets]

            # Apply mask to the gradient block
            logits_chunk[program_id, block_offsets] = torch.where(
                mask, grad_block, existing_values
            )


def fused_linear_cross_entropy_forward(
    _input: torch.Tensor,
    weight: torch.Tensor,
    target: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Forward pass with chunking strategy similar to Liger."""
    device = _input.device
    BT, H = _input.shape
    V = weight.shape[0]

    # Calculate chunk size to limit memory usage
    inc_factor = (V + H - 1) // H
    chunk_size = min(MAX_FUSED_SIZE, (BT + inc_factor - 1) // inc_factor)
    chunk_size = min(chunk_size, BT)
    num_chunks = (BT + chunk_size - 1) // chunk_size

    # Initialize gradients and loss
    grad_input = torch.zeros_like(_input)
    grad_weight = torch.zeros_like(weight) if weight.requires_grad else None
    grad_bias = torch.zeros_like(bias) if bias is not None else None
    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)

    # Process in chunks
    for chunk_id in range(num_chunks):
        start_idx = chunk_id * chunk_size
        end_idx = min((chunk_id + 1) * chunk_size, BT)
        actual_chunk_size = end_idx - start_idx

        # Get chunk of input and target
        input_chunk = _input[start_idx:end_idx]  # [chunk_size, H]
        target_chunk = target[start_idx:end_idx]  # [chunk_size]

        # Compute logits for this chunk
        logits_chunk = input_chunk @ weight.t()  # [chunk_size, V]
        if bias is not None:
            logits_chunk = logits_chunk + bias

        # Ensure contiguous for kernel
        logits_chunk = logits_chunk.contiguous()
        target_chunk = target_chunk.contiguous()

        # Get loss slice
        loss_chunk = loss_1d[start_idx:end_idx]

        # Call kernel - this modifies logits_chunk in-place to contain gradients
        cross_entropy_kernel(
            logits_chunk,
            target_chunk,
            loss_chunk,
            actual_chunk_size,
            V,
            BT,  # Pass total number of samples for mean reduction
        )

        # Now logits_chunk contains gradients
        # Compute input gradient: grad_input = grad_logits @ weight
        grad_input[start_idx:end_idx] = logits_chunk.detach() @ weight.detach()

        # Accumulate weight gradients if needed
        if grad_weight is not None:
            # grad_weight += grad_logits.T @ input
            # Detach tensors to avoid autograd issues with in-place operations
            torch.addmm(
                input=grad_weight,
                mat1=logits_chunk.detach().t(),
                mat2=input_chunk.detach(),
                out=grad_weight,
                alpha=1.0,
                beta=1.0,
            )

        if grad_bias is not None:
            torch.add(
                input=grad_bias,
                other=logits_chunk.detach().sum(dim=0),
                out=grad_bias,
                alpha=1.0,
            )

    # Return total loss
    loss = loss_1d.sum()

    return loss, grad_input, grad_weight, grad_bias


# User-facing function
def fused_linear_cross_entropy(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Fused linear + cross entropy."""
    loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
        input_tensor, weight, labels, bias
    )

    # For this example, we just return the loss
    # In a real implementation with autograd, we'd save gradients for backward
    return loss


def fused_linear_cross_entropy_pytorch(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """PyTorch reference implementation for fused linear cross entropy."""
    # Compute logits
    logits = torch.matmul(input_tensor, weight.T)
    if bias is not None:
        logits = logits + bias
    # Compute cross entropy
    return torch.nn.functional.cross_entropy(logits, labels)


def main() -> None:
    n, h, v = 128, 512, 1000
    torch.manual_seed(42)
    input_tensor = torch.randn(n, h, device="cuda", dtype=torch.float32)
    weight = torch.randn(v, h, device="cuda", dtype=torch.float32)
    labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)

    run_example(
        fused_linear_cross_entropy,
        fused_linear_cross_entropy_pytorch,
        (input_tensor, weight, labels),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-3,
        atol=1e-3,
    )


if __name__ == "__main__":
    main()
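
Note on the chunking arithmetic above: with the shapes used in main() (BT=128, H=512, V=1000), inc_factor = ceil(1000/512) = 2 and chunk_size = min(32768, ceil(128/2)) = 64, so the forward pass runs in two chunks. For readers new to the online softmax used in cross_entropy_kernel, the following standalone PyTorch sketch (not part of the commit; the block size of 16 is an arbitrary choice) shows that the same streaming max/sum recurrence reproduces a direct log-sum-exp:

import torch

# Streaming log-sum-exp over fixed-size blocks, mirroring the kernel's
# first pass: m tracks the running max, d the running sum of exp(x - m).
logits = torch.randn(1000)
block = 16  # arbitrary tile width for this sketch

m = torch.tensor(float("-inf"))
d = torch.tensor(0.0)
for start in range(0, logits.numel(), block):
    logits_block = logits[start : start + block]
    m_new = torch.maximum(m, logits_block.max())
    # Rescale the old sum to the new max, then add this block's terms
    d = d * torch.exp(m - m_new) + torch.exp(logits_block - m_new).sum()
    m = m_new

lse = m + torch.log(d)
torch.testing.assert_close(lse, torch.logsumexp(logits, dim=0))

The per-sample loss is then lse - logits[target]; dividing it by n_total_samples inside the kernel makes the host-side sum over loss_1d equal to a mean-reduced cross entropy.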

test/test_examples.expected

Lines changed: 83 additions & 0 deletions
@@ -705,6 +705,89 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
    _launcher(_fp8_gemm_kernel, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_fused_linear_cross_entropy)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _linear_kernel(input_1, weight, logits, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
    num_blocks_0 = tl.cdiv(128, _BLOCK_SIZE_0)
    pid_0 = tl.program_id(0) % num_blocks_0
    pid_1 = tl.program_id(0) // num_blocks_0
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    offset_1 = pid_1 * _BLOCK_SIZE_1
    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    mask_1 = indices_1 < 1000
    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
    for offset_2 in tl.range(0, 512, _BLOCK_SIZE_2):
        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
        acc_copy = acc
        acc_copy_0 = acc_copy
        load = tl.load(input_1 + (indices_0[:, None] * 512 + indices_2[None, :] * 1), None)
        load_1 = tl.load(weight + (indices_1[:, None] * 512 + indices_2[None, :] * 1), mask_1[:, None], other=0)
        permute = tl.permute(load_1, [1, 0])
        acc = tl.dot(load, permute, acc=acc_copy_0, input_precision='ieee')
    tl.store(logits + (indices_0[:, None] * 1000 + indices_1[None, :] * 1), acc, mask_1[None, :])

def linear(input: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launcher):
    n, h = input.shape
    v, h2 = weight.shape
    assert h == h2, f'Hidden size mismatch: {h} != {h2}'
    logits = torch.empty([n, v], dtype=torch.float32, device=input.device)
    _BLOCK_SIZE_0 = 16
    _BLOCK_SIZE_1 = 16
    _BLOCK_SIZE_2 = 16
    _launcher(_linear_kernel, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(1000, _BLOCK_SIZE_1),), input, weight, logits, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return logits

--- assertExpectedJournal(TestExamples.test_fused_linear_cross_entropy)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _cross_entropy_loss_kernel(labels, base_indices, logits_flat, logits, losses, base_indices_stride_0, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_0 = offset_0 + tl.zeros([1], tl.int32)
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < v
    labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
    base_indices_tile = tl.load(base_indices + indices_0 * base_indices_stride_0, None)
    v_0 = base_indices_tile + labels_tile
    logits_at_target = tl.load(logits_flat + v_0 * logits_flat_stride_0, None)
    logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
    _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
    max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
    v_1 = logits_rows - max_logits
    v_2 = tl_math.exp(v_1)
    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_2, 0)
    sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
    squeeze = tl.reshape(max_logits, [1])
    squeeze_1 = tl.reshape(sum_exp, [1])
    v_3 = tl_math.log(squeeze_1)
    v_4 = squeeze + v_3
    v_5 = v_4 - logits_at_target
    tl.store(losses + indices_0 * losses_stride_0, v_5, None)

def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_default_launcher):
    n, v = logits.shape
    losses = torch.zeros([n], dtype=torch.float32, device=logits.device)
    base_indices = torch.arange(n, device=logits.device) * v
    logits_flat = logits.view(-1)
    _RDIM_SIZE_1 = triton.next_power_of_2(v)
    _launcher(_cross_entropy_loss_kernel, (n,), labels, base_indices, logits_flat, logits, losses, base_indices.stride(0), labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return losses.mean()

--- assertExpectedJournal(TestExamples.test_jagged_dense_add)
from __future__ import annotations
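
A detail worth noting in the generated code: cross_entropy_loss fetches each sample's target logit through a flattened view (base_indices = arange(n) * v, then logits_flat[base + label]) rather than 2-D indexing. A minimal sketch of that gather trick, assuming contiguous row-major logits (illustrative only, not from the commit):

import torch

# Gather logits[i, labels[i]] through a flattened view, as the generated
# cross_entropy_loss kernel does. Assumes logits is contiguous, so row i
# starts at flat offset i * v.
n, v = 4, 7
logits = torch.randn(n, v)
labels = torch.randint(0, v, (n,))

base_indices = torch.arange(n) * v
logits_flat = logits.view(-1)
target_logits = logits_flat[base_indices + labels]

torch.testing.assert_close(target_logits, logits[torch.arange(n), labels])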

test/test_examples.py

Lines changed: 51 additions & 0 deletions
@@ -281,6 +281,57 @@ def test_cross_entropy(self):
            )
        )

    def test_fused_linear_cross_entropy(self):
        # Test the fused kernel
        n, h, v = 128, 512, 1000
        input_tensor = torch.randn(n, h, device=DEVICE, dtype=torch.float32)
        weight = torch.randn(v, h, device=DEVICE, dtype=torch.float32)
        labels = torch.randint(0, v, (n,), device=DEVICE, dtype=torch.long)

        args = (input_tensor, weight, labels)
        # Compute expected loss using PyTorch
        logits = torch.matmul(input_tensor, weight.T)
        expected_loss = torch.nn.functional.cross_entropy(logits, labels)

        self.assertExpectedJournal(
            check_example(
                "fused_linear_cross_entropy",
                args,
                expected_loss,
                fn_name="fused_linear_cross_entropy_kernel",
            )
        )

        # Also test the individual kernels for backward compatibility
        # Test the linear kernel
        linear_args = (input_tensor, weight)
        expected_logits = torch.matmul(input_tensor, weight.T)

        self.assertExpectedJournal(
            check_example(
                "fused_linear_cross_entropy",
                linear_args,
                expected_logits,
                fn_name="linear",
            )
        )

        # Test the cross_entropy_loss kernel
        logits = torch.randn(n, v, device=DEVICE, dtype=torch.float32)
        labels2 = torch.randint(0, v, (n,), device=DEVICE, dtype=torch.long)

        ce_args = (logits, labels2)
        expected_loss2 = torch.nn.functional.cross_entropy(logits, labels2)

        self.assertExpectedJournal(
            check_example(
                "fused_linear_cross_entropy",
                ce_args,
                expected_loss2,
                fn_name="cross_entropy_loss",
            )
        )

    def test_rms_norm(self):
        args = (
            torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
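
The fused test compares against torch.nn.functional.cross_entropy with its default mean reduction, which lines up with the kernel because each per-sample loss is divided by the total sample count before the host sums the chunks. A small sketch of that equivalence (illustrative only, not part of the commit):

import torch
import torch.nn.functional as F

# Summing per-sample losses that were pre-divided by n reproduces the
# default mean reduction, up to floating-point rounding.
torch.manual_seed(0)
n, v = 8, 10
logits = torch.randn(n, v)
labels = torch.randint(0, v, (n,))

per_sample = F.cross_entropy(logits, labels, reduction="none")
loss_kernel_convention = (per_sample / n).sum()

torch.testing.assert_close(loss_kernel_convention, F.cross_entropy(logits, labels))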
