Commit 6b3d761
[Fix] Replace conditional flow with tl.where in liger_cross_entropy_kernel for Triton 3.2 compatibility (#991)
## Summary
Fixes a CompilationError in `liger_cross_entropy_kernel` when using
`RETURN_TOKEN_ACCURACY=True` on Triton 3.2.0 (PyTorch 2.6).
The current implementation uses a Python `if` statement
(`if RETURN_TOKEN_ACCURACY and block_max > m`) whose condition includes a
runtime tensor comparison. On stricter or older Triton versions this
raises a ValueError during JIT compilation, because the compiler attempts
to cast a tensor to a boolean. This PR replaces the conditional control
flow with `tl.where`, which performs the data-dependent selection
explicitly and is in line with Triton best practices.
## Details
The Issue: On environments with Triton 3.2.0, running the cross-entropy
kernel (`pytest test/transformers/test_cross_entropy.py`) triggers the
following error:
```txt
triton.compiler.errors.CompilationError: at 120:11:
...
def make_ir(self, options, codegen_fns, module_map, context):
> return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
module_map=module_map)
E triton.compiler.errors.CompilationError: at 120:11:
E X_block = tl.load(
E X_ptr + X_offsets,
E mask=X_offsets < n_cols,
E other=float("-inf"),
E # Ensure float32 precision for softmax calculation
E ).cast(tl.float32)
E if HAS_SOFTCAPPING:
E X_block = softcap * tanh(X_block / softcap)
E block_max = tl.max(X_block)
E
E # Track argmax for accuracy computation
E if RETURN_TOKEN_ACCURACY and block_max > m:
E ^
E ValueError('Cannot bitcast data-type of size 32 to data-type of size 1')
```
The Fix: I replaced the nested Python `if` statement with a `tl.where`
operation.
* Before: Relied on implicit compiler predication for `if block_max >
m`, which fails when the compiler strictly evaluates the Python boolean
context.
* After: Explicitly calculates the `argmax` for the current block and
uses `tl.where(is_new_max, ...)` to update the global index. This is the
canonical way to handle tensor-dependent logic in Triton.
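
To make the before/after distinction concrete, here is a minimal plain-Python sketch of the same update pattern. It is not the kernel code from this PR: `select` is a hypothetical scalar stand-in for `tl.where`, `running_argmax` is an illustrative name, and the loop over list slices only emulates the kernel's block loop. The variable names `m`, `block_max`, `argmax_idx`, and `is_new_max` are borrowed from the description above.

```python
def select(cond, if_true, if_false):
    # Hypothetical scalar stand-in for tl.where: both candidate values are
    # computed before the call, and cond only picks which one is kept.
    return if_true if cond else if_false


def running_argmax(row, block_size):
    # Emulates a block-wise pass over one row of logits (illustrative only).
    m = float("-inf")  # running maximum over the blocks seen so far
    argmax_idx = 0     # index of the running maximum

    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        block_max = max(block)

        # Always compute this block's argmax, even if it ends up unused.
        block_argmax = start + block.index(block_max)

        # Select instead of branching on the comparison result.
        is_new_max = block_max > m
        argmax_idx = select(is_new_max, block_argmax, argmax_idx)
        m = max(m, block_max)

    return m, argmax_idx


row = [0.5, 2.0, -1.0, 2.0, 3.5, 0.0, 3.5]
print(running_argmax(row, 3))  # (3.5, 4)
```

The strict `>` comparison keeps the earliest index when a later block ties with the running maximum, so the result matches `row.index(max(row))` for any block size.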
## Testing Done
- Hardware Type: NVIDIA A100-SXM4-80GB
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

1 parent 01a66f1, commit 6b3d761
1 file changed: +5 -2 lines changed

Changed lines: original line 146 replaced by new line 146; original line 152 replaced by new lines 152-155.