Skip to content

Conversation

@whitneywhtsang
Copy link
Contributor

This PR change the Triton base from c172d53 to 9f21c06 (Oct 27).
Pass rate: 93.92%->

matthias-springer and others added 30 commits October 22, 2025 13:37
This commit adds tests for `unrealized_conversion_cast` and make the
visitor in the axis analysis more robust.

The input AxisInfo should not be propagated if that would inject an
AxisInfo with an incorrect rank. That could cause a crash.

`unrealized_conversion_cast` ops typically appear during a dialect
conversion, but they are also useful for debugging / rapid prototyping
purposes. It allows programmers to hand-write the expected low-level IR
and connect it with high-level IR that will be lowered as usual. In such
a scenario, programmers write IR such as "Case 2" in the added test
case. Such IR currently crashes the axis analysis.

Also fix a crash when a function call with multi-dimensional function
argument is analyzed. (This is now triggered due to the improved
`unrealized_conversion_cast` handling.)
This PR add device-side TMA support for gluon. cc @ThomasRaoux
Use the correct condition value when other value exists.
We fix a number of cases where the constancy analysis could be improved.

The code is quite messy, and the whole pass could do with a full
rewrite, but we are not doing so ATM.

This PR was mostly vibecoded, with a cleaning pass afterwards from me.
…s. (#8512)

A few tests in tensor descriptor use "cuda" as device rather than a
'device'
fixture in the test arguments. This PR changes those tests to use
'device'
fixture instead so that third party users without a cuda runtime can run
on these tests.

<!---
The core Triton is a small number of people, and we receive many PRs
(thank
you!).  To help us review your code more quickly, **if you are a new
contributor (less than 3 PRs merged) we ask that you complete the
following
tasks and include the filled-out checklist in your PR description.**

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.
-->

# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
- [x] This PR does not need a test because it is editing the test file
only.

- Select one of the following.
  - [x] I have not added any `lit` tests.
- [ ] The `lit` tests I have added follow these [best
practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
including the "tests should be minimal" section. (Usually running Python
code
    and using the instructions it generates is not minimal.)

Co-authored-by: Micah Weston <[email protected]>
Add shared memory capacity for `gfx1250` which is 320 kbyte.
This RP fixes the layout and lowering for wmma scaled with small k dim
where the tensor's k dimension is smaller than the a single wmma scaled
instruction's k dimension. Add corresponding lit tests for common cases.
Prevent crash when lowering memdesc of pointer
…n (#8493)

TP > 1 is not supported in this mode
Adds support for lowering and in `UpdateAsyncWaitCnt`.

Note that on `gfx1250` async_loads use `asynccnt` which is separate from
register and TDM loads so they can finish out of order. This means
register and tdm loads should be ignored by `UpdateAsyncWaitCnt` and no
performance remark for the former compared to `GFX9`.

Intrinsics will be replaced by ROCDL ops once we bumped LLVM.
mainly 3 changes to `deduceTilesPerWarp`
    1) consider scale A and B vecSize together
    2) consider constant scale
    3) limit to block boundary

Bug fix of a pre-shuffled dot_scaled_a8w4 with activation fp8,
constant scale and weight mxfp4 with pres-huffled scale.
when block_m = 16 and the instr_size is [16, 16, 128],
tilesPerWarp=[2,2] will result in dummy mfma instructions.
Lowering is mostly the same as on GFX9 except that we support per lane
destination addresses so we do not need to swizzle the source pointers
and the LDS address is not a scalar anymore. Additionally we support
64bit per lane.
Note that we do not have (async) buffer load to lds on `gfx1250`.

The intrinsic will be replaced by a ROCDL op once it's ready in LLVM.
AxisInfo analysis currently retrieves the rank from any `ShapedType`
producing `PoisonOp`. This is a problem if the `PoisonOp` actually
produces a `MemDesc`, since the value produced by the `PoisonOp` may
flow into the same value as some other `MemDesc` producing operation,
which will have been assigned the "pessimistic state" and have rank 1.

When we attempt to join the two, the ranks will not match, potentially
resulting in a crash.

# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [ ] I have not added any `lit` tests.
- [x] The `lit` tests I have added follow these [best
practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
including the "tests should be minimal" section. (Usually running Python
code
    and using the instructions it generates is not minimal.)
- Enable cp.async.bulk.tensor.2d.tile::gather4.shared on sm_120 and
sm_121.
- Skip TMA scatter4 test on sm_120 since it is unsupported by hardware.

Note:
All other TMA features except for cluster-related ones are supported on
sm_120.
This PR exposes the internal layout utility
`chooseScaledMfmaScaleLayout` and `chooseScaledWmmaScaleLayout` for
Gluon, to help generate a linear layout for scale used in
`mfma_scaled`/`wmma_scaled`. This also allows gluon kernels to specify a
scalar scale value or leave it as None.
Without resetting opt_flags, the following does not work and gives error
`AssertionError: opt_flags already set; please reset to None first`:

```
import torch
from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig
from triton_kernels.matmul_ogs_details.opt_flags import (
    make_opt_flags,
    set_opt_flags,
)
from triton_kernels.routing import RoutingData

m = 64
n = 128
k = 32
BATCH_SIZE = 1000
dtype = torch.float16

x = torch.randn((BATCH_SIZE, m, k), device="cuda", dtype=dtype)
w = torch.randn((BATCH_SIZE, k, n), device="cuda", dtype=dtype)
bias = None

opt_flags = make_opt_flags(
    dtype,
    dtype,
    dtype,
    PrecisionConfig(),
    m,
    n,
    k,
    RoutingData(None, None, BATCH_SIZE, 1),
    True,
    False,
    False,
)

set_opt_flags(opt_flags)
tri_y = matmul_ogs(x, w, bias)

opt_flags.num_warps = 2
set_opt_flags(opt_flags)
tri_y = matmul_ogs(x, w, bias)
```

After adding `reset_opt_flags()` before the second call of
`set_opt_flags` everything works fine.
Functions and their individual arguments are passed as an array. All the
arguments are just appended together in MLIR, but the
`WarpSpecializeOp::canonicalize` method will clean up duplicate
arguments.
This is in preparation for more examples to add
and be consistent with other directory names.
…8531)

`warp_specialize` ops currently have unknown location set in the TTGIR
due to a quirk in the code emission in `_semantic.py`: for
`warp_specialize` we need save and then restore insert point. Location
is being inferred from the insert point, however if insert point happens
to be in a place that doesn't have location assigned (end of a block),
we set unknown loc. This change is a minimal fix that adds a helper that
gets the location from block's parent in such a case.
Alternatively we could also save location along with insert point, and
then restore it accordingly. This approach is simpler and should help
for most cases I could have think of however.
This change is important for consan changes I am working on, as it
breaks the LLVM backend if we create instrumentation function calls with
unknown location inferred from warp_specialize op.
…c (#8529)

During SWP, we are checking if a given `LoadOp` should be lowered to
`AsyncCopyGlobalToLocalOp` twice - first in `AssignLatency`, and
`LowerLoops` next. The two checks duplicate non-trivial conditions like
`copyVecBytes >= 4` or `op.getResultTypes()[0].getIntOrFloatBitWidth()
>= 32`.

I moved the `isPipeliningBeneficial` function from `AssignLatency` into
utilities so that it can also be used by `LowerLoops`. This will also be
used by WS to determine if `LoadOp` should be lowered to cpasync and
assigned to the load partition.
Expose `buffer_load` and `buffer_store`, inherited from CDNA3,
to gfx1250.
<!---
The core Triton is a small number of people, and we receive many PRs
(thank
you!).  To help us review your code more quickly, **if you are a new
contributor (less than 3 PRs merged) we ask that you complete the
following
tasks and include the filled-out checklist in your PR description.**

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.
-->

# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [x] I have not added any `lit` tests.
- [ ] The `lit` tests I have added follow these [best
practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
including the "tests should be minimal" section. (Usually running Python
code
    and using the instructions it generates is not minimal.)
…528)

Each aggregate class tracks its callable members and when the aggregate
is referenced by name, the cache keys of all its members are computed.
This does require `def __init__` to be marked as `@constexpr_function`
The heuristic for swapping was for 8-bit act I believe and this was
hurting perf of bf16 act

On GB200, M=8K x N ~= 1K and K ~= 5K, ~50 TF/s -> ~100 TF/s

<!---
The core Triton is a small number of people, and we receive many PRs
(thank
you!).  To help us review your code more quickly, **if you are a new
contributor (less than 3 PRs merged) we ask that you complete the
following
tasks and include the filled-out checklist in your PR description.**

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.
-->

# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [ ] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [x] I have not added any `lit` tests.
- [ ] The `lit` tests I have added follow these [best
practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
including the "tests should be minimal" section. (Usually running Python
code
    and using the instructions it generates is not minimal.)
The `constexprs` field is never used outside `jit.py` and can be
computed on-demand from `self.params`. This addresses part of the
TODO(jlebar) at line 735.
When `TRITON_INTERPRET=1`, the cache should always be disabled (in both
cases, `cache_results=true` and `TRITON_CACHE_AUTOTUNING=1`). Otherwise,
there is an error in `check_disk_cache` since it is searching for a
`JITFunction` which is not available.
This has been discussed in #6678 already but the introduced changes
still enable caching results when `cache_results=true`.
…8542)

In particular, we generalise:
- `tl.trans(x)` from 2 dimensions to >= 2 dimensions, matching the
behaviour of numpy's `x.mT`;
- `tl.dot(A, B)` from 2 or 3 dimensions to >= 2 dimensions, supporting
multidimensional batches;
such that they batch over the first `n - 2` dimensions and operate on
the last 2 dimensions.

The generalised `tl.trans(x)` is useful for implementing the derivative
of `tl.dot`, as the same code can be written irrespective of the number
of batch dimensions:

    C = tl.dot(A, B)
    A_grad = tl.dot(C_grad, tl.trans(B))
    B_grad = tl.dot(tl.trans(A), C_grad)

Unit tests are included which test the new functionality of both
`tl.dot` and `tl.trans`.
ThomasRaoux and others added 4 commits October 27, 2025 08:54
The alloc_shape is not useful when doing memdesc_index, we should only
set it when slicing a memdesc
… matmul tutorial (#8553)

This makes it explicit that worker partitions run in additional warps,
above those specified in `num_warps`.
@whitneywhtsang whitneywhtsang self-assigned this Oct 29, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.