Skip to content

Commit 1f52259

Browse files
Merge OpenAI Triton commit 65d9862 (#4534)
This PR change the Triton base from 993c8da to 65d9862 (Jun 16). Pass rate: 97.12%
2 parents 6c50b82 + dda8a43 commit 1f52259

File tree

21 files changed

+205
-59
lines changed

21 files changed

+205
-59
lines changed

.github/workflows/integration-tests-amd.yml

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ jobs:
1414
strategy:
1515
matrix:
1616
runner: ${{ fromJson(inputs.matrix) }}
17+
include:
18+
- image: rocm/pytorch:rocm6.2.2_ubuntu22.04_py3.10_pytorch_2.5.1_asan
19+
- image: rocm/7.0-preview:rocm7.0_preview_ubuntu22.04_llama2_70b_training_mlperf_mi35X_prealpha
20+
runner: ["amd-gfx950"]
1721
env:
1822
RUNNER_TYPE: ${{ matrix.runner[1] }}
1923
TRITON_BUILD_WITH_CCACHE: "true"
@@ -24,7 +28,7 @@ jobs:
2428
PYTHON: "python3"
2529
CCACHE_COMPRESS: "true"
2630
container:
27-
image: rocm/pytorch:rocm6.2.2_ubuntu22.04_py3.10_pytorch_2.5.1_asan
31+
image: ${{ matrix.image }}
2832
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
2933
steps:
3034
- name: Checkout
@@ -76,9 +80,14 @@ jobs:
7680
run: |
7781
echo "PATH is '$PATH'"
7882
pip uninstall -y triton pytorch-triton-rocm
79-
ccache --zero-stats
83+
84+
if [ "${{ matrix.runner[0] }}" != "amd-gfx950" ]; then
85+
ccache --zero-stats
86+
fi
87+
8088
make dev-install
8189
- name: CCache Stats
90+
if: ${{ matrix.runner[0] != 'amd-gfx950' }}
8291
run: ccache --print-stats
8392
- name: Run lit tests
8493
run: make test-lit
@@ -101,7 +110,12 @@ jobs:
101110
pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py
102111
103112
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
104-
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
113+
if [ "${{ matrix.runner[0] }}" = "amd-gfx950" ]; then
114+
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py -k "not test_line_info_ir_source"
115+
else
116+
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
117+
fi
118+
105119
- name: Run asan tests on AMD
106120
if: false
107121
run: |
@@ -121,8 +135,12 @@ jobs:
121135
cd python/test/regression
122136
python3 -m pytest -s -n 8 ./test_cast_matmul.py
123137
- name: Run Proton tests
124-
if: ${{ matrix.runner[0] != 'nvidia-gb200' }}
125-
run: make test-proton
138+
run: |
139+
if [ "${{ matrix.runner[0] }}" = "amd-gfx950" ]; then
140+
python3 -m pytest -s -n 8 third_party/proton/test -k "not test_instrument_exec"
141+
else
142+
make test-proton
143+
fi
126144
- name: Run C++ unittests
127145
run: make test-cpp
128146
- name: Inspect cache directories

.github/workflows/runner-preparation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ jobs:
9696
run: |
9797
if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then
9898
echo '::set-output name=matrix-NVIDIA::[["nvidia-a100"], ["nvidia-h100"], ["nvidia-gb200"]]'
99-
echo '::set-output name=matrix-AMD::[["self-hosted", "gfx90a"], ["amd-gfx942"]]'
99+
echo '::set-output name=matrix-AMD::[["self-hosted", "gfx90a"], ["amd-gfx942"], ["amd-gfx950"]]'
100100
echo '::set-output name=matrix-MACOS::[["macos-latest"]]'
101101
else
102102
echo '::set-output name=matrix-NVIDIA::["ubuntu-latest"]'
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef TRITON_DIALECT_TRITON_IR_DISCARDABLE_ATTRIBUTES_H_
2+
#define TRITON_DIALECT_TRITON_IR_DISCARDABLE_ATTRIBUTES_H_
3+
4+
#include "mlir/Support/LLVM.h"
5+
#include "triton/Dialect/Triton/IR/Dialect.h"
6+
7+
namespace mlir::triton {
8+
9+
// Filter out attributes from the given operation that are not present in
10+
// the allowList.
11+
[[nodiscard]] SmallVector<NamedAttribute>
12+
filterDiscardableAttrs(Operation *op, ArrayRef<StringRef> allowList);
13+
14+
} // namespace mlir::triton
15+
#endif // TRITON_DIALECT_TRITON_IR_DISCARDABLE_ATTRIBUTES_H_

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ bool isPureUnaryInlineAsm(Operation *op);
208208
int getNVIDIAComputeCapability(Operation *module);
209209

210210
// Read the amd target from the module attributes
211-
StringRef getAMDArch(Operation *module);
211+
std::optional<StringRef> getAMDArch(Operation *module);
212212

213213
std::optional<mlir::triton::gpu::SwizzledSharedEncodingAttr>
214214
getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);

lib/Dialect/Triton/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_public_tablegen_target(TritonCanonicalizeIncGen)
44

55
add_triton_library(TritonIR
66
Dialect.cpp
7+
DiscardableAttributes.cpp
78
Ops.cpp
89
Traits.cpp
910
Types.cpp
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include "mlir/Support/LLVM.h"
2+
#include "triton/Dialect/Triton/IR/Dialect.h"
3+
4+
namespace mlir::triton {
5+
6+
SmallVector<NamedAttribute>
7+
filterDiscardableAttrs(Operation *op, ArrayRef<StringRef> allowList) {
8+
SmallVector<NamedAttribute> propagatedAttrs;
9+
for (auto attrName : allowList) {
10+
Attribute attr = op->getDiscardableAttr(attrName);
11+
if (attr)
12+
propagatedAttrs.emplace_back(attrName, attr);
13+
}
14+
return propagatedAttrs;
15+
}
16+
17+
} // namespace mlir::triton

lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "mlir/Support/LogicalResult.h"
77
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
88
#include "triton/Dialect/Triton/IR/Dialect.h"
9+
#include "triton/Dialect/Triton/IR/DiscardableAttributes.h"
910
#include "triton/Dialect/Triton/Transforms/Passes.h"
1011

1112
namespace mlir::triton {

lib/Dialect/Triton/Transforms/Combine.td

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,15 @@ def CombineDotAddFRevPattern : Pat<
3939
// Note: leave (sub %c0, %c0) canceling to ArithDialect
4040
// (ref: ArithCanonicalization.td)
4141
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
42+
43+
def CopyDiscardableAttrs: NativeCodeCallVoid<
44+
"$1.getOwner()->setDiscardableAttrs(triton::filterDiscardableAttrs($0.getOwner(), "
45+
"{\"tt.divisibility\", \"tt.contiguity\", \"tt.constancy\"}))">;
46+
4247
def CombineAddPtrPattern : Pat<
43-
(TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
44-
(TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)),
45-
[(Constraint<CPred<"isAddPtrOffsetCombinable($0, $1)">> $idx0, $idx1)]>;
48+
(TT_AddPtrOp:$src (TT_AddPtrOp $ptr, $idx0), $idx1),
49+
(TT_AddPtrOp:$dest $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)),
50+
[(Constraint<CPred<"isAddPtrOffsetCombinable($0, $1)">> $idx0, $idx1)],
51+
[(CopyDiscardableAttrs $src, $dest)]>;
4652

4753
#endif

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,14 +1055,19 @@ int getNVIDIAComputeCapability(Operation *module) {
10551055
return computeCapability;
10561056
}
10571057

1058-
StringRef getAMDArch(Operation *module) {
1058+
std::optional<StringRef> getAMDArch(Operation *module) {
10591059
StringAttr targetAttr =
10601060
module->getAttrOfType<StringAttr>(triton::gpu::AttrTargetName);
1061-
assert(targetAttr && "Expected a target attribute on the module operation");
1061+
if (!targetAttr) {
1062+
LDBG("Expected a target attribute on the module operation");
1063+
return {};
1064+
}
10621065

10631066
StringRef ref = targetAttr.strref();
1064-
assert(ref.starts_with("hip:") &&
1065-
"expected target attribute to be prefixed with \"hip:\"");
1067+
if (!ref.starts_with("hip:")) {
1068+
LDBG("expected target attribute to be prefixed with \"hip:\"");
1069+
return {};
1070+
}
10661071

10671072
return ref.drop_front(4); // drop the "hip:"
10681073
}

test/Triton/combine.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,29 @@ tt.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f3
8484
tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
8585
}
8686

87+
// CHECK-LABEL: @test_combine_addptr_pattern_discardableattrs
88+
tt.func @test_combine_addptr_pattern_discardableattrs(%base: !tt.ptr<f32>) -> !tt.ptr<f32> {
89+
%off0 = arith.constant 8 : i32
90+
%off1 = arith.constant 4 : i32
91+
// CHECK-NEXT: %[[cst:.*]] = arith.constant 12 : i32
92+
// CHECK-NEXT: %0 = tt.addptr %{{.*}}, %[[cst]] {tt.constancy = 8 : i32, tt.contiguity = 512 : i32, tt.divisibility = 16 : i32} : !tt.ptr<f32>, i32
93+
%ptr0 = tt.addptr %base, %off0 : !tt.ptr<f32>, i32
94+
%ptr1 = tt.addptr %ptr0, %off1 {tt.divisibility = 16 : i32, tt.constancy = 8 : i32, tt.contiguity = 512 : i32} : !tt.ptr<f32>, i32
95+
96+
tt.return %ptr1 : !tt.ptr<f32>
97+
}
98+
99+
// CHECK-LABEL: @test_combine_addptr_pattern_discardableattrs_disallowed
100+
tt.func @test_combine_addptr_pattern_discardableattrs_disallowed(%base: !tt.ptr<f32>) -> !tt.ptr<f32> {
101+
%off0 = arith.constant 8 : i32
102+
%off1 = arith.constant 4 : i32
103+
// CHECK-NEXT: %[[cst:.*]] = arith.constant 12 : i32
104+
// CHECK-NEXT: %0 = tt.addptr %{{.*}}, %[[cst]] {tt.divisibility = 16 : i32} : !tt.ptr<f32>, i32
105+
%ptr0 = tt.addptr %base, %off0 : !tt.ptr<f32>, i32
106+
%ptr1 = tt.addptr %ptr0, %off1 {tt.divisibility = 16 : i32, tt.disallowed = 8 : i32} : !tt.ptr<f32>, i32
107+
108+
tt.return %ptr1 : !tt.ptr<f32>
109+
}
87110
// CHECK-LABEL: @test_combine_addptr_pattern_i64
88111
tt.func @test_combine_addptr_pattern_i64(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
89112
%off0 = arith.constant 10 : i64

0 commit comments

Comments
 (0)