Commit 7a4cfe7

wenqiny and Jokeren authored
[INTERPRETER] Fix tl.max consistency issue (#7349)
# `tl.max` shows different behavior from `torch.max` and the interpreter

Related issue: #6635

## Summary

I found that `tl.max()` does not return `nan` when the input tensor contains a `nan`, whereas both `torch.max` and the interpreter return `nan` if one is present in the input. It looks like:

```
tl.max("nan", "inf") = "inf"
torch.max("nan", "inf") = "nan"
```

I think this may cause data-consistency issues with torch, as in the issue mentioned above.

## Case

<details>
<summary> Simple repro </summary>

```python
import triton
import triton.language as tl
import torch


@triton.jit
def max_kernel(X_ptr, Y_ptr, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)  # [0, 1, ..., 63]
    x = tl.load(X_ptr + offsets)  # Load 64 elements
    max_val = tl.max(x, axis=0)  # Compute max across all 64 elements
    # Only one block (e.g., block 0) writes out the result
    if tl.program_id(0) == 0:
        tl.store(Y_ptr, max_val)


inf = float("Inf")
# inf = float(999999)
nan = float("nan")

# A tensor containing one "nan" scalar; every other element is "inf"
# (any finite float would also work).
x = torch.tensor([nan] + [inf] * 63, device='cuda:0')

# Output tensor: just one element to hold the max
y = torch.empty(1, device='cuda')

# Launch the kernel with 1 block
max_kernel[1,](x, y, BLOCK_SIZE=64)

# Fetch the result
print("Max value triton:", y[0].item())
print("Max value torch:", torch.max(x).item())  # for validation
```

</details>

When we run the code above on **a tensor containing one "nan" scalar, with every other element "inf" or any finite float**, the output is:

```
Max value triton: inf
Max value torch: nan
```

If we run triton in interpreter mode, the output is:

```
Max value triton: nan
Max value torch: nan
```

We can see that **triton behaves differently from torch and the interpreter**.

## Root cause

The root cause is that `_elementwise_max` generates an [`arith.maxnumf`](https://mlir.llvm.org/docs/Dialects/ArithOps/#arithmaxnumf-arithmaxnumfop) op by default, which means **"If one of the arguments is NaN, then the result is the other argument."**

https://github.com/triton-lang/triton/blob/3b41514dc2526628deadbe5271b5596ffa2fb820/python/triton/language/semantic.py#L380-L381

## Next steps

This PR is still in a draft stage, but if it makes sense to you, I will fix any tests that break and extend the change to `tl.min` as well. Separately, I've noticed that `tl.argmax` also exhibits an inconsistency with torch and the interpreter; I plan to address that in a follow-up PR once this one is finalized.

<!---
The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them.
-->

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `It's in RFC stage, will add test later`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

---------

Co-authored-by: Keren Zhou <[email protected]>
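The `arith.maxnumf` semantics described in the root cause can be mirrored in plain NumPy (which is what the interpreter runs on): `np.fmax` follows the same rule (a NaN operand is ignored and the other argument is returned), while `np.maximum` propagates NaN the way `torch.max` does on a pair of scalars. A minimal sketch, NumPy only, no GPU required:

```python
import math

import numpy as np

# arith.maxnumf semantics: if one argument is NaN, the result is the
# other argument -- np.fmax behaves the same way.
print(np.fmax(np.nan, np.inf))     # inf

# NaN-propagating semantics, as torch.max exhibits on this input:
print(np.maximum(np.nan, np.inf))  # nan

# Sanity checks
assert np.fmax(np.nan, np.inf) == np.inf
assert math.isnan(np.maximum(np.nan, np.inf))
```

This is the same split the PR resolves: compiled `tl.max` follows the `fmax`-style "nan ignore" rule, so the interpreter is updated to match it.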
1 parent 00d5ca7 commit 7a4cfe7

File tree

2 files changed

+45
-2
lines changed

python/test/unit/language/test_core.py

Lines changed: 43 additions & 0 deletions
```diff
@@ -2422,6 +2422,49 @@ def kernel(X, Z, BLOCK: tl.constexpr):
     assert z[0] == 0


+@pytest.mark.interpreter
+def test_max_min_with_nan(device):
+    # In triton, we implement a "nan ignore" style: if there is a NaN in the
+    # reduce dimension, we ignore it and return the max/min of the remaining
+    # numbers. This differs from torch.max/min, which propagate NaN.
+    @triton.jit
+    def max_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.arange(0, BLOCK_SIZE)
+        x = tl.load(x_ptr + offsets)
+
+        max_val = tl.max(x, axis=0)
+
+        if tl.program_id(0) == 0:
+            tl.store(y_ptr, max_val)
+
+    @triton.jit
+    def min_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.arange(0, BLOCK_SIZE)
+        x = tl.load(x_ptr + offsets)
+
+        min_val = tl.min(x, axis=0)
+
+        if tl.program_id(0) == 0:
+            tl.store(y_ptr, min_val)
+
+    BLOCK_SIZE = 64
+    x = torch.rand((1, BLOCK_SIZE), dtype=torch.float32, device=device)
+    # NaN should be ignored, not returned, by tl.max/tl.min
+    x[0, 0] = float('nan')
+    # Expected output for tl.min
+    x[0, 1] = float('-inf')
+    # Expected output for tl.max
+    x[0, 2] = float('inf')
+
+    y = torch.ones(1, device=device)
+
+    max_kernel[(1, )](x, y, BLOCK_SIZE=BLOCK_SIZE)
+    assert y[0] == float('inf')
+
+    min_kernel[(1, )](x, y, BLOCK_SIZE=BLOCK_SIZE)
+    assert y[0] == float('-inf')
+
 def get_reduced_dtype(dtype_str, op):
     if op in ('argmin', 'argmax'):
         return 'int32'
```

python/triton/runtime/interpreter.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -934,9 +934,9 @@ def apply_impl(self, input):
         elif self.combine_fn == tl.standard._argmax_combine_tie_break_left:
             return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax)
         elif self.combine_fn == tl.standard._elementwise_max:
-            return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None)
+            return self.min_max(input[0], val_reduce_op=np.nanmax, idx_reduce_op=None)
         elif self.combine_fn == tl.standard._elementwise_min:
-            return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None)
+            return self.min_max(input[0], val_reduce_op=np.nanmin, idx_reduce_op=None)
         elif self.combine_fn == tl.standard._sum_combine:
             return self.sum(input[0])
         else:
```
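The effect of swapping `np.max`/`np.min` for `np.nanmax`/`np.nanmin` in the interpreter can be checked with a small NumPy snippet, independent of the interpreter internals: `np.max` propagates NaN, while `np.nanmax` ignores it and returns the largest non-NaN element, which matches the "nan ignore" behavior of compiled `tl.max`. A sketch:

```python
import math

import numpy as np

# Reduction input with one NaN, mirroring the new test's setup.
x = np.array([np.nan, np.inf, 1.0, -np.inf], dtype=np.float32)

# Before the fix the interpreter used np.max/np.min, which propagate NaN.
assert math.isnan(np.max(x))
assert math.isnan(np.min(x))

# After the fix it uses np.nanmax/np.nanmin, which ignore NaN and so
# agree with compiled tl.max/tl.min on this input.
assert np.nanmax(x) == np.inf
assert np.nanmin(x) == -np.inf
```

Note that `np.nanmax` still returns NaN (with a warning) when *every* element is NaN, so the all-NaN case remains an edge case for both backends.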
