Skip to content

Conversation

@whitneywhtsang
Copy link
Contributor

@whitneywhtsang whitneywhtsang commented Sep 30, 2024

This PR change the Triton base from 493f991 to e7ec3fe (Sept 26).
Pass rate: 98.99%

Please do not squash and merge this PR.

Moerafaat and others added 14 commits September 25, 2024 08:58
…4803)

triton-lang/triton#4492 started causing an
issue where chained MMAs on hopper would segfault with 8 warps. It seems
that previously this was checked, but the check got removed in this PR
and it's still unsupported.

Adding back this check means these MMAs will have to go back to shared
memory, but it's better than segfaulting until it's actually supported.

Resolves openxla/xla#17356

Co-authored-by: Tori <[email protected]>
Now we can have negative steps so use signed division.
There was a bug in shared memory base pointer calculation which misuses
`gep` build function and pass `ptr` type as shared memory's element type.
Fixed the bug.
LLVM AMDGPU backend supports special intrinsics
(https://llvm.org/docs/AMDGPUUsage.html#llvm-ir-intrinsics)
as hints to influence instruction scheduling. This PR
adds basic scaffolding for utilizing those intrinsics to better
control instructions generated from the backend. It is meant
to only target `tt.dot` operations which are often the
most intensive ones and may demand fine-tuning to achieve
better performance.

Facilities added here are experimental and we need to iterate
on it until to a good state.
This commit turns on the v2 pipeliner as the default.
We still keep v1 for some extended time to make
perf debugging easier; but expect to remove it soon.
Currently the reduction codegen unconditionally executes the combine
region which can create problems because we conditionally load from
shared memory, so this uses uninitialized registers.

Generally combine regions should be pure, so this shouldn't be
observable but with the overflow sanitizer the frontend injects
assertions into the combine region.

This changes the `accumulate` function to take a predicate and if the
combine region isn't speculateble we only run it on threads where the
predicate is true. In the common case, the codegen is unchanged.
…he key of config cache. (#4808)

The autotuner uses the index of the arguments of the Triton kernel
signature to look up the value to be used as the key for the config
cache.
There is an issue if the user pass the kernel arguments as keyword args
in arbitrary order.

The name of the argument should be used to look up the value of the args
passed by the user instead of the `key_idx`.
This prevents the autotuner from using a mismatched value as the key for
caching when the arguments are passed in arbitrary order as keyword
args.

- [X] I am not making a trivial change, such as fixing a typo in a
comment.

- [X] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [X] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [X] I have not added any `lit` tests.
- [ ] The `lit` tests I have added follow these [best
practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
including the "tests should be minimal" section. (Usually running Python
code
    and using the instructions it generates is not minimal.)
…persistent-matmul.py` (#4802)

- Format the introduction section in some tutorials.
- Add instructions for running the persistent matmul tutorial, as well
as instructions for using `proton-viewer`.
- Replace `torch.zeros` with `torch.empty` to remove unnecessary GPU
kernels.
- Add brackets `[` and `]` around shapes to improve the output
formatting.
- Remove redundant metric accumulation, as the Triton hook already
handles metric accumulation.
Summary: split loads so each group of uses with the same shared encoding
will have a corresponding load. This enables pipelining loads with
incompatible shared encoding.

AMD has its own version of assignMemoryLayouts, so the test case
load_two_users_incompatible_layouts will have different results for AMD.
For indirect loads, we try to assign them to later stages
```
unsigned stagesBetweenLoads =
   ceil<unsigned>(numStages - 2, maxIndirectionLevel + 1);
int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads;
schedule.insert(loadOp, stage, loadsClusters[indLevel]);
```
If numStages is 2, there is no later stage to assign the indirect loads
to. The fix is to not pipeline the indirect loads.
We also generalize to not pipeline an indirect load if the indirection
level >= numStages - 1
…4774)

Triton LLVM codegen has a bug where local_loads from #shared to #mma
layout can lead to invalid code if the loaded shape is smaller than the
mma tile. Remove the workaround.

See triton-lang/triton#3561.

Verified that with test case: https://pastebin.com/xxP3cFmy (test.mlir),
running
triton-opt test.mlir -tritongpu-pipeline=num-stages=3
--convert-scf-to-cf --allocate-shared-memory
--convert-triton-gpu-to-llvm
has no issue.

Unit test case added in triton-lang/triton#4798
also shows no issue.
@whitneywhtsang whitneywhtsang self-assigned this Sep 30, 2024
@whitneywhtsang whitneywhtsang changed the title Merge OpenAI Triton commit 4348109 Merge OpenAI Triton commit e7ec3fe Sep 30, 2024
@whitneywhtsang whitneywhtsang marked this pull request as ready for review September 30, 2024 18:52
@whitneywhtsang whitneywhtsang merged commit 2a4b054 into main Sep 30, 2024
4 checks passed
@whitneywhtsang whitneywhtsang deleted the whitneywhtsang/merge branch September 30, 2024 19:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.