Commit 6b3d761

[Fix] Replace conditional flow with tl.where in liger_cross_entropy_kernel for Triton 3.2 compatibility (#991)
## Summary

Fixes a `CompilationError` in `liger_cross_entropy_kernel` when using `RETURN_TOKEN_ACCURACY=True` on Triton 3.2.0 (PyTorch 2.6). The current implementation uses a Python conditional statement (`if RETURN_TOKEN_ACCURACY and block_max > m`) with a runtime tensor condition. This causes a `ValueError` during JIT compilation on stricter or older Triton versions, because the compiler attempts to cast a tensor to a boolean. This PR replaces the conditional control flow with `tl.where` to ensure robust, data-dependent selection compatible with Triton best practices.

## Details

**The Issue:** On environments with Triton 3.2.0, running the cross-entropy tests (`pytest test/transformers/test_cross_entropy.py`) triggers the following error:

```txt
triton.compiler.errors.CompilationError: at 120:11:
...
    def make_ir(self, options, codegen_fns, module_map, context):
>       return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, module_map=module_map)
E   triton.compiler.errors.CompilationError: at 120:11:
E       X_block = tl.load(
E           X_ptr + X_offsets,
E           mask=X_offsets < n_cols,
E           other=float("-inf"),
E           # Ensure float32 precision for softmax calculation
E       ).cast(tl.float32)
E       if HAS_SOFTCAPPING:
E           X_block = softcap * tanh(X_block / softcap)
E       block_max = tl.max(X_block)
E
E       # Track argmax for accuracy computation
E       if RETURN_TOKEN_ACCURACY and block_max > m:
E                                    ^
E   ValueError('Cannot bitcast data-type of size 32 to data-type of size 1')
```

**The Fix:** I replaced the nested Python `if` statement with a `tl.where` operation.

* Before: relied on implicit compiler predication for `if block_max > m`, which fails when the compiler strictly evaluates the Python boolean context.
* After: explicitly calculates the argmax for the current block and uses `tl.where(is_new_max, ...)` to update the global index. This is the canonical way to handle tensor-dependent logic in Triton.

## Testing Done

- Hardware Type: NVIDIA A100-SXM4-80GB
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
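For readers less familiar with the pattern, below is a minimal, standalone Triton sketch of the same idea: keep the data-dependent update inside `tl.where` instead of a Python `if` on a runtime tensor. This is not the repository's kernel; the kernel name, launch shapes, and block size are made up for illustration, and it assumes a CUDA device plus a recent Triton install.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def rowwise_argmax_kernel(X_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    # One program per row; X is assumed contiguous with row stride n_cols.
    row = tl.program_id(0)
    X_ptr += row * n_cols

    m = float("-inf")  # running max seen so far
    argmax_idx = 0     # column index of the running max

    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf"))
        block_max = tl.max(X_block)

        # Smallest column index at which the block maximum occurs.
        block_argmax = tl.min(tl.where(X_block == block_max, X_offsets, n_cols))

        # Data-dependent update via tl.where. Writing `if block_max > m:` here
        # would ask the compiler to treat a runtime tensor as a Python bool,
        # which is what raised the CompilationError on Triton 3.2.0.
        is_new_max = block_max > m
        argmax_idx = tl.where(is_new_max, block_argmax, argmax_idx)
        m = tl.maximum(m, block_max)

    tl.store(out_ptr + row, argmax_idx)


# Example launch (illustrative shapes; requires a CUDA device).
if torch.cuda.is_available():
    x = torch.randn(4, 1000, device="cuda")
    out = torch.empty(4, dtype=torch.int32, device="cuda")
    rowwise_argmax_kernel[(x.shape[0],)](x, out, x.shape[1], BLOCK_SIZE=256)
    assert torch.equal(out.cpu(), x.argmax(dim=1).to(torch.int32).cpu())
```

The `tl.where(is_new_max, ...)` line is the essence of the change made in this commit; everything else is scaffolding for the sketch.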
1 parent 01a66f1 commit 6b3d761

File tree

1 file changed: +5 −2 lines changed


src/liger_kernel/ops/cross_entropy.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -143,13 +143,16 @@ def liger_cross_entropy_kernel(
         block_max = tl.max(X_block)
 
         # Track argmax for accuracy computation
-        if RETURN_TOKEN_ACCURACY and block_max > m:
+        if RETURN_TOKEN_ACCURACY:
             # Find the index of the maximum value in this block
             is_max_mask = X_block == block_max
             # Mask out invalid indices with a value larger than n_cols
             masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
             # Get the first (smallest) index where max occurs
-            argmax_idx = tl.min(masked_offsets)
+            current_block_argmax_idx = tl.min(masked_offsets)
+
+            is_new_max = block_max > m
+            argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
 
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
```
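Semantically, the blockwise update in the diff above is a running argmax over the full row. A small plain-PyTorch emulation of what the loop-carried `m` / `argmax_idx` pair computes may help when reviewing; the helper name and block size here are hypothetical and not part of this change.

```python
import torch


def blockwise_running_argmax(row: torch.Tensor, block_size: int) -> int:
    """Emulate the kernel's per-block running max/argmax update on one row."""
    m = float("-inf")
    argmax_idx = 0
    for start in range(0, row.numel(), block_size):
        chunk = row[start : start + block_size]
        block_max = chunk.max().item()
        # Smallest index where the block maximum occurs (first occurrence).
        current_block_argmax_idx = start + int((chunk == block_max).nonzero()[0].item())
        # Mirrors: argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
        if block_max > m:
            argmax_idx = current_block_argmax_idx
        m = max(m, block_max)
    return argmax_idx


row = torch.randn(4096)
assert blockwise_running_argmax(row, block_size=128) == int(row.argmax())
```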
