@whitneywhtsang whitneywhtsang commented Nov 14, 2025

This PR changes the Triton base from 3c2e6f8 to c186592 (Oct 29).
Pass rate: 94.91%->94.95%

ptillet and others added 8 commits October 29, 2025 00:18
…nerically (#8421) (#8495)

This PR relands triton-lang/triton#8386.
It depends on triton-lang/triton#8492 to avoid
regressing in some workloads.
There's silent data corruption when calling `tl.histogram` with the
interpreter.

```python
# test.py
import torch
import ctypes
import triton
import triton.language as tl


@triton.jit
def histogram_kernel(x_ptr, z_ptr):
    offset = tl.arange(0, 1)
    x = tl.load(x_ptr + offset)
    z = tl.histogram(x, 1)
    buf = (ctypes.c_int32 * 2).from_address(int(z_ptr))

    print(f'before store: {list(buf)}')
    tl.store(z_ptr + offset, z) # tl.store treats z values as int64 while they're int32
    print(f'after store: {list(buf)}')


device = 'cpu'
torch.manual_seed(17)
x = torch.ones(1, device=device, dtype=torch.int32)
z = torch.ones(2, dtype=torch.int32, device=device)
histogram_kernel[(1, )](x, z)

# Output:
# TRITON_INTERPRET=1 TRITON_TEST_SUITE=interpreter python test.py 
# before store: [1, 1]
# after store: [1, 0] <- second element shouldn't be cleared
```

Based on the `np.histogram` docs:
https://numpy.org/doc/2.3/reference/generated/numpy.histogram.html
the returned dtype is taken from the optional `weights` parameter when it
is passed, and defaults to int64 otherwise.
That leads to `tl.store` thinking it's saving int64 values while the
tensor passed in my example contains int32, so it writes 8 bytes at a
time instead of 4, exceeding the buffer's data range and causing silent
data corruption.

```python
import numpy as np

data = np.array([1], dtype=np.int32)
bins = 1

print(f'Data dtype before: {data.dtype}')
histogram = np.histogram(data, bins=bins, range=(0, bins))[0]
print(f'Data dtype after: {histogram.dtype}')

# Data dtype before: int32                                                                                                                                           
# Data dtype after: int64
```

Applying "dummy_weights" fixes the returned data type as expected,
eliminating the data corruption.
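The fix can be illustrated directly with NumPy (the `dummy_weights` name here mirrors the description above rather than the patch itself; the point is that passing `weights` of the input's integer dtype pins the returned dtype):

```python
import numpy as np

data = np.array([1], dtype=np.int32)
bins = 1

# Without weights, np.histogram defaults the counts to int64.
# Passing all-ones weights of the input's dtype makes the returned
# counts use that dtype instead, so tl.store writes 4 bytes, not 8.
dummy_weights = np.ones_like(data, dtype=np.int32)
hist = np.histogram(data, bins=bins, range=(0, bins), weights=dummy_weights)[0]
print(hist.dtype)
```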

------------------------------


# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
- [x] This PR does not need a test because it fixes np.histogram-specific
behavior in interpreter mode.

- Select one of the following.
  - [x] I have not added any `lit` tests.
- [ ] The `lit` tests I have added follow these [best
practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
including the "tests should be minimal" section. (Usually running Python
code and using the instructions it generates is not minimal.)
… transactions (#8575)

`ttg.async_wait` counts the number of outstanding `ttg.commit_groups`.
However, when lowering to LLVM on AMD we require the number of
outstanding async intrinsics/final assembly instructions. The conversion
is already done by `UpdateAsyncWaitCount`, which modifies the `num` of
`ttg.async_wait` in place.
This PR introduces a new op `amdgpu.async_wait` to make the change in
semantics explicit in the IR.

`UpdateAsyncWaitCount` is moved to `TTGIR->LLVM` primarily so that it
also covers `Gluon` kernels, and we should always call it since it only
has an effect if there are `ttg.async_wait` ops present in the kernel.

To avoid membar changes this also adds a `ttgpu.LocalBarrier` after each
`amdgpu.async_wait`. Membar will respect the newly added barrier and
behave the same as for `ttg.async_wait`.
Fixes #8578
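As a purely illustrative IR sketch of the change (the exact assembly syntax, operand names, and `num` values below are assumptions, not taken from the patch):

```mlir
// Before: num counts outstanding commit groups.
%0 = ttg.async_wait %token {num = 1 : i32}

// After: UpdateAsyncWaitCount converts num to the number of outstanding
// async instructions, and the dedicated op makes the AMD-specific
// semantics explicit; the local barrier keeps membar behavior unchanged.
%0 = amdgpu.async_wait %token {num = 4 : i32}
ttg.local_barrier
```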

We're using the wrong output constraint, which leads LLVM to extend the
fp16 value to 32 bits. Fixing the constraint removes the conversion.

Note that we still end up with a no-op sequence like:
```ptx
mov.b32 {%rs1, %rs2}, %r1
mov.b32 %r2, {%rs1, %rs2}
```

However, `ptxas` is able to optimize these out.
### The Problem with the Original Formula
The original formula is:
```
tanh(x) = (e^(2x) - 1) / (e^(2x) + 1)
```
- Issue with large positive x:
   - When x = 20: e^(40) ≈ 2.4 × 10^17 → manageable
   - When x = 50: e^(100) ≈ 2.7 × 10^43 → overflow to infinity
   - Result: (∞ - 1)/(∞ + 1) = ∞/∞ = NaN
- For negative x: The formula actually works fine because e^(2x) → 0,
giving (-1)/(1) = -1

### The Numerically Stable Solution
- For Positive x: Reformulation
```
tanh(x) = (e^(2x) - 1) / (e^(2x) + 1) = (e^(2x) + 1 - 2) / (e^(2x) + 1) = 1 - 2/(e^(2x) + 1)
```
- For Negative x: Using Symmetry
```
tanh(-x) = (e^(-2x) - 1) / (e^(-2x) + 1) = 2/(e^(2x) + 1) - 1 = -1 × (1 - 2/(e^(2|x|) + 1))
```

### Unified formulation:
```
tanh(x) = sign(x) × (1 - 2/(e^(2|x|) + 1))
```
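A minimal scalar Python sketch of the unified formulation (the `stable_tanh` name is mine; real kernels operate on fp32 vectors, where an overflowing `e^(2|x|)` becomes +∞ and IEEE semantics give 2/(∞ + 1) = 0, saturating the result to ±1 — here Python's `math.exp` raises instead, so we map the overflow to infinity explicitly):

```python
import math

def stable_tanh(x: float) -> float:
    # tanh(x) = sign(x) * (1 - 2 / (e^(2|x|) + 1))
    sign = 1.0 if x >= 0 else -1.0
    try:
        e = math.exp(2.0 * abs(x))
    except OverflowError:
        # IEEE float would overflow to +inf; 2/(inf + 1) == 0.0,
        # so the result saturates cleanly to sign(x) * 1.0.
        e = math.inf
    return sign * (1.0 - 2.0 / (e + 1.0))
```

Unlike the original formula, large inputs saturate to ±1 instead of producing (∞ - 1)/(∞ + 1) = NaN.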
@whitneywhtsang whitneywhtsang self-assigned this Nov 14, 2025
@whitneywhtsang whitneywhtsang marked this pull request as ready for review November 15, 2025 17:23
@whitneywhtsang whitneywhtsang merged commit 4546255 into main Nov 15, 2025
23 checks passed
@whitneywhtsang whitneywhtsang deleted the whitneywhtsang/merge branch November 15, 2025 17:23