Description
I’m reporting an issue with challenge `65_geglu` (Gaussian Error Gated Linear Unit): a correct pure Triton implementation fails due to a consistent small numerical mismatch against the reference implementation, even when computing the exact same formula.
Reference implementation (from challenge)
The grader reference is:
```python
x1, x2 = input.chunk(2)
gelu = 0.5 * x2 * (1.0 + torch.erf(x2 / math.sqrt(2.0)))
output.copy_(x1 * gelu)
```

with tolerances `atol=1e-05`, `rtol=1e-05`.
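For context, the reference formula can be wrapped as a self-contained function (adding the `math` import it relies on; the function name here is mine, not the grader’s):

```python
import math

import torch


def reference_geglu(input: torch.Tensor) -> torch.Tensor:
    # GEGLU(x) = x1 * GELU(x2), using the exact (erf-based) GELU
    x1, x2 = input.chunk(2)
    gelu = 0.5 * x2 * (1.0 + torch.erf(x2 / math.sqrt(2.0)))
    return x1 * gelu
```

This is equivalent to `x1 * F.gelu(x2)` with PyTorch’s default (non-tanh) GELU approximation.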
Problem
A pure Triton kernel computing the same formula with `tl.erf` fails with:
max abs diff = 3.0517578125e-05
This happens with:
- fp32 compute
- fp64 intermediates (still fails after casting to fp32 output)
- hardcoded constants 0.7071067811865476
This strongly suggests `tl.erf` does not match `torch.erf` on CUDA closely enough to satisfy a 1e-5 tolerance for all random inputs (e.g., N=1024 drawn from uniform(-100, 100)).
Why this matters
On the Triton leaderboard, the few “Triton” solutions that pass appear to do so by bypassing Triton and calling PyTorch ops instead. That implies the task is effectively not solvable in pure Triton under the current grading tolerances and reference.
Repro (pure Triton)
```python
import torch
import triton
import triton.language as tl


@triton.jit
def geglu(input, output, N, BLOCK_SIZE: tl.constexpr):
    middle = N // 2
    idx = tl.program_id(0)
    offset = idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < middle
    x = tl.load(input + offset, mask=mask, other=0.0)
    y = tl.load(input + offset + middle, mask=mask, other=0.0)
    # erf(y / sqrt(2)) written as erf(y * sqrt(2) * 0.5) to avoid a division
    res = x * 0.5 * y * (1.0 + tl.erf(y * tl.sqrt(2.0) * 0.5))
    tl.store(output + offset, res, mask=mask)


# input, output are tensors on the GPU
def solve(input: torch.Tensor, output: torch.Tensor, N: int):
    BLOCK_SIZE = 256
    grid = (triton.cdiv(N // 2, BLOCK_SIZE),)
    geglu[grid](input, output, N, BLOCK_SIZE=BLOCK_SIZE)
```

Request / Possible fixes
Could you please consider one of:
- Relax the tolerance for this challenge (e.g. `atol=5e-5` or `1e-4`), since different GPU erf implementations can differ slightly.
- Change the reference to a formulation that matches Triton’s available math more robustly (or compare against a CPU high-precision reference with looser tolerance).
- Document that the Triton track is allowed to call `torch.erf` (if that’s intended), though that defeats the purpose of a Triton-only challenge.
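As a sketch of the second option, the grader could compare against a float64 CPU reference with a slightly looser `atol` (the function below is illustrative, not the grader’s actual code):

```python
import math

import torch


def check_geglu(actual: torch.Tensor, input: torch.Tensor,
                atol: float = 5e-5, rtol: float = 1e-5) -> bool:
    # Build the expected output in float64 on the CPU, so the comparison
    # does not depend on any particular GPU's erf implementation.
    x = input.detach().double().cpu()
    x1, x2 = x.chunk(2)
    expected = x1 * 0.5 * x2 * (1.0 + torch.erf(x2 / math.sqrt(2.0)))
    return torch.allclose(actual.detach().double().cpu(), expected,
                          atol=atol, rtol=rtol)
```

A float32 kernel whose only error comes from rounding and from minor erf differences should pass this check comfortably, while a wrong formula still fails.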
Extra: what to change (suggested tolerance)
Given the observed diff is ~3.05e-05, a minimal safe fix is `atol = 5e-05` (or `1e-4` to be robust across GPUs).
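To see why `atol=5e-5` absorbs the observed gap while the current tolerances do not: `torch.allclose` accepts a result when `|actual - expected| <= atol + rtol * |expected|`.

```python
import torch

expected = torch.tensor([1.0])
actual = expected + 3.0517578125e-05  # the observed max abs diff

# Current tolerances reject it: 3.05e-5 > 1e-5 + 1e-5 * 1.0
print(torch.allclose(actual, expected, atol=1e-5, rtol=1e-5))  # False

# The proposed atol accepts it: 3.05e-5 <= 5e-5 + 1e-5 * 1.0
print(torch.allclose(actual, expected, atol=5e-5, rtol=1e-5))  # True
```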