Skip to content

Commit b58720a

Browse files
Merge OpenAI Triton commit 690f690 (#4914)
This PR changes the Triton base from 4bcdbde to 690f690 (Aug 1). Pass rate: 98.86%->98.85% (#4916)
2 parents 0ab03be + 76f767f commit b58720a

File tree

148 files changed

+3374
-1359
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

148 files changed

+3374
-1359
lines changed

.github/CODEOWNERS

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,13 @@ lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @ptillet
4646
# third_party
4747
# -----------
4848
third_party/amd/ @antiagainst @zhanglx13
49+
50+
# -----------
51+
# gluon
52+
# -----------
53+
python/triton/experimental/gluon/ @peterbell10
54+
python/src/gluon_ir.cc @peterbell10
55+
python/test/gluon @peterbell10
56+
test/Gluon @peterbell10
57+
include/triton/Dialect/Gluon @peterbell10
58+
lib/Dialect/Gluon @peterbell10

.github/workflows/integration-tests-amd.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ jobs:
117117
if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
118118
echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
119119
fi
120+
121+
# Test gluon
122+
pytest --capture=tee-sys -rfs -n 8 python/test/gluon/
123+
120124
pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
121125
pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice_concat_op.py
122126
TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "triton/Conversion/MLIRTypes.h"
55

66
namespace mlir::triton {
7+
enum class ProgramIDDim : uint32_t;
78

89
class TargetInfoBase {
910
public:
@@ -48,7 +49,7 @@ class TargetInfoBase {
4849
Value i) const = 0;
4950

5051
virtual Value programId(RewriterBase &rewriter, Location loc,
51-
ModuleOp moduleOp, int axis) const = 0;
52+
ModuleOp moduleOp, ProgramIDDim axis) const = 0;
5253

5354
virtual bool warpReduce(RewriterBase &rewriter, Location loc,
5455
SmallVector<Value> &acc, triton::ReduceOp op,

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,9 @@ inline bool isCanonicalIndex(unsigned index, unsigned freeVarMask) {
614614
// group code isolated from above by invoking this function.
615615
void makeAllWarpGroupsIsolatedFromAbove(Operation *op);
616616

617+
// Set the correct loop annotation on LLVM branch ops.
618+
void fixUpLoopAnnotation(ModuleOp mod);
619+
617620
/// Converts ConvertLayoutOp to LLVM using padded pattern.
618621
/// This pattern adds unused memory locations after every row of the tensor's fastest
619622
/// changing dimension:

include/triton/Dialect/Triton/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,15 @@ def TritonLoopAwareCSE : Pass<"triton-loop-aware-cse", "mlir::ModuleOp"> {
9090
}];
9191
}
9292

93+
def TritonSCFToCF : Pass</*cli-arg*/"triton-scf-to-cf", /*Op*/"mlir::ModuleOp"> {
94+
let summary = "MLIR's SCF To CF plus some extra attributes propagation.";
95+
let description = [{
96+
This pass uses MLIR's SCF To CF pass as base. Additionally, it propagates
97+
some extra attributes to the converted CFG.
98+
TODO: upstream the llvm loop attribute propagation and remove this pass.
99+
}];
100+
101+
let dependentDialects = [];
102+
}
103+
93104
#endif

include/triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ include "mlir/IR/EnumAttr.td"
66
def TT_MemTypeAttr : I32EnumAttr<
77
"MemType", "",
88
[
9-
I32EnumAttrCase<"SHARED", 0, "shared">,
10-
I32EnumAttrCase<"TENSOR", 1, "tensor">,
9+
I32EnumAttrCase<"SHARED_MEM", 0, "shared_mem">,
10+
I32EnumAttrCase<"TENSOR_MEM", 1, "tensor_mem">,
1111
]> {
1212
let cppNamespace = "::mlir::triton::instrument";
1313
}

include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def TTI_ExperimentalBufferPointersOp : TTI_Op<"experimental_buffer_pointers", [P
4747
}
4848

4949

50-
def TTI_ExperimentalCheckOutstandingWritesOp : TTI_Op<"experimental_check_outstanding_writes", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
50+
def TTI_ExperimentalCheckWriteStateOp : TTI_Op<"experimental_check_write_state", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
5151
let summary = "check if there are outstanding writes to a buffer guarded by a mbar";
5252
let description = [{
5353
Check if the writeState tensor has non-zero value associated with the buffer.
@@ -76,7 +76,7 @@ def TTI_ExperimentalCheckOutstandingWritesOp : TTI_Op<"experimental_check_outsta
7676
}
7777

7878

79-
def TTI_ExperimentalCheckOutstandingReadsOp : TTI_Op<"experimental_check_outstanding_reads", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
79+
def TTI_ExperimentalCheckReadBarriersOp : TTI_Op<"experimental_check_read_barriers", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
8080
let summary = "check if there are outstanding reads from a buffer guarded by a mbar";
8181
let description = [{
8282
Check if there are outstanding reads from a buffer guarded by a mbar.
@@ -95,8 +95,8 @@ def TTI_ExperimentalCheckOutstandingReadsOp : TTI_Op<"experimental_check_outstan
9595
}
9696

9797

98-
def TTI_ExperimentalMarkAsWriteOp : TTI_Op<"experimental_mark_as_write", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
99-
let summary = "mark a buffer as being written to using mbar as a guard";
98+
def TTI_ExperimentalSetWriteStateOp : TTI_Op<"experimental_set_write_state", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
99+
let summary = "mark a buffer as being written in writeState tensor";
100100
let description = [{
101101
Mark a buffer as being written to. It is not yet tracked by a barrier, until
102102
`commit_write_with_barrier` is called, at which point all the buffers being written
@@ -146,7 +146,7 @@ def TTI_ExperimentalCommitWriteWithBarrierOp : TTI_Op<"experimental_commit_write
146146
}
147147

148148

149-
def TTI_ExperimentalMarkAsReadOp : TTI_Op<"experimental_mark_as_read", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
149+
def TTI_ExperimentalSetReadBarrierOp : TTI_Op<"experimental_set_read_barrier", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
150150
let summary = "mark a buffer as being read from using mbar as a guard";
151151
let description = [{
152152
Mark a buffer as being read from using mbar as a guard.
@@ -226,72 +226,70 @@ def TTI_ExperimentalCheckBarrierWritesClearedOp : TTI_Op<"experimental_check_bar
226226
}
227227

228228

229-
// TODO: Potentially resolve the naming/functionality clash with commit_write_with_barrier
230-
def TTI_ExperimentalStageWriteForCommitOp : TTI_Op<"experimental_stage_write_for_commit", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
231-
let summary = "Preapre to an async copy of a buffer. Staged until commit_group is called.";
229+
def TTI_ExperimentalStageAccessForCommitOp : TTI_Op<"experimental_stage_access_for_commit", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
230+
let summary = "";
232231
let description = [{
233-
Preapre to an async copy of a buffer. Staged until commit_group is called. The implementation will write `-1` to the
234-
`write_commits` tensor under the indices corresponding to the buffer.
232+
For operations that use `outstanding` to track the number of outstanding commits (rather than mbarriers),
233+
mark the buffer as being accessed, but not committed yet, by marking it with `-1`.
235234
}];
236235
let arguments = (ins
237236
TTG_MemDescType:$buf,
238237
TT_Tensor:$buffers,
239-
TT_PtrLike:$writeCommits,
240-
TypeAttr:$writeCommitsType,
238+
TT_PtrLike:$outstandingCommits,
239+
TypeAttr:$outstandingCommitsType,
241240
Optional<I1>:$pred
242241
);
243242
let assemblyFormat = [{
244-
$buf `{` $buffers `,` $writeCommits `(` $writeCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($writeCommits)
243+
$buf `{` $buffers `,` $outstandingCommits `(` $outstandingCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($outstandingCommits)
245244
}];
246-
// let hasVerifier = 1;
245+
let hasVerifier = 1;
247246
}
248247

249-
def TTI_ExperimentalCommitWritesOp : TTI_Op<"experimental_commit_writes", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
250-
let summary = "Commit all the staged writes for all the buffers.";
248+
def TTI_ExperimentalCommitAccessesOp : TTI_Op<"experimental_commit_accesses", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
249+
let summary = "Commit all the staged accesses for all the buffers.";
251250
let description = [{
252-
Commit all the staged writes for all the buffers.
251+
Commit all the staged accesses for all the buffers.
253252
}];
254253
let arguments = (ins
255-
TT_PtrLike:$writeCommits,
256-
TypeAttr:$writeCommitsType,
254+
TT_PtrLike:$outstandingCommits,
255+
TypeAttr:$outstandingCommitsType,
257256
Optional<I1>:$pred);
258257
let assemblyFormat = [{
259-
`{` $writeCommits `(` $writeCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($writeCommits)
258+
`{` $outstandingCommits `(` $outstandingCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($outstandingCommits)
260259
}];
261-
// let hasVerifier = 1;
262260
}
263261

264-
def TTI_ExperimentalClearWriteCommitsOp : TTI_Op<"experimental_clear_write_commits", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
265-
let summary = "Clear all the write commits more distant than `outstandingNum.";
262+
def TTI_ExperimentalClearOutstandingCommitsOp : TTI_Op<"experimental_clear_outstanding_commits", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
263+
let summary = "Clear all the outstanding commits more distant than `outstandingNum`.";
266264
let description = [{
267-
Clear all the write commits more distant than `outstandingNum` from the current thread.
265+
Clear all the outstanding commits more distant than `outstandingNum` from the current thread.
268266
}];
269267
let arguments = (ins
270-
TT_PtrLike:$writeCommits,
271-
TypeAttr:$writeCommitsType,
268+
TT_PtrLike:$outstandingCommits,
269+
TypeAttr:$outstandingCommitsType,
272270
I32Attr:$outstandingNum,
273271
Optional<I1>:$pred);
274272
let assemblyFormat = [{
275-
`{` $writeCommits `(` $writeCommitsType `)` `}` `,` $outstandingNum (`,` $pred^)? attr-dict `:` type($writeCommits)
273+
`{` $outstandingCommits `(` $outstandingCommitsType `)` `}` `,` $outstandingNum (`,` $pred^)? attr-dict `:` type($outstandingCommits)
276274
}];
277-
// let hasVerifier = 1;
278275
}
279276

280-
def TTI_ExperimentalCheckWriteCommitOp : TTI_Op<"experimental_check_write_commit", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
281-
let summary = "Check if the buffer has an outstanding write commit.";
277+
def TTI_ExperimentalCheckOutstandingCommitsOp : TTI_Op<"experimental_check_outstanding_commits", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
278+
let summary = "Check if the buffer has an outstanding commit.";
282279
let description = [{
283-
Check if the buffer has an outstanding write commit.
280+
Check if the buffer has an outstanding commit.
284281
}];
285282
let arguments = (ins
286283
TTG_MemDescType:$buf,
287284
TT_Tensor:$buffers,
288-
TT_PtrLike:$writeCommits,
289-
TypeAttr:$writeCommitsType,
285+
TT_PtrLike:$outstandingCommits,
286+
TypeAttr:$outstandingCommitsType,
287+
StrAttr:$pendingAccessType,
290288
Optional<I1>:$pred);
291289
let assemblyFormat = [{
292-
$buf `{` $buffers `,` $writeCommits `(` $writeCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($writeCommits)
290+
$buf `{` $buffers `,` $outstandingCommits `(` $outstandingCommitsType `)` `}` `,` $pendingAccessType (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($outstandingCommits)
293291
}];
294-
// let hasVerifier = 1;
292+
let hasVerifier = 1;
295293
}
296294

297295
#endif // TRITONINSTRUMENT_OPS

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterface
123123
}];
124124

125125
let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
126+
let hasVerifier = 1;
126127
}
127128

128129
def TTNG_InitBarrierOp : TTNG_Op<"init_barrier"> {

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
6363
while (auto callLoc = dyn_cast<CallSiteLoc>(loc))
6464
loc = callLoc.getCallee();
6565

66+
while (auto nameLoc = dyn_cast<NameLoc>(loc))
67+
loc = nameLoc.getChildLoc();
68+
6669
if (auto fileLineColLoc = dyn_cast<FileLineColLoc>(loc)) {
6770
file = fileLineColLoc.getFilename();
6871
line = fileLineColLoc.getLine();

lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
2626
ConversionPatternRewriter &rewriter) const override {
2727
auto loc = op->getLoc();
2828

29-
auto getPid = [&](int axis) {
30-
return targetInfo.programId(rewriter, loc,
31-
op->getParentOfType<ModuleOp>(), axis);
32-
};
33-
std::array<Value, 3> pid = {getPid(0), getPid(1), getPid(2)};
29+
std::array<Value, 3> pid;
30+
auto module = op->getParentOfType<ModuleOp>();
31+
for (auto axis : {ProgramIDDim::X, ProgramIDDim::Y, ProgramIDDim::Z})
32+
pid[(int)axis] = targetInfo.programId(rewriter, loc, module, axis);
3433

3534
// Simple printf of a string without any tensors.
3635
if (op.getNumOperands() == 0) {

0 commit comments

Comments
 (0)