Skip to content

Commit 006191b

Browse files
Merge OpenAI Triton commit 7d18fd8 (#4755)
This PR change the Triton base from 2b5505c to 7d18fd8 (Jul 11). Pass rate: 98.46%
2 parents 6ac742f + 4aac0d3 commit 006191b

File tree

48 files changed

+2029
-479
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+2029
-479
lines changed

include/triton/Dialect/Gluon/Transforms/Passes.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,28 @@ def GluonResolveAutoEncodingsPass : Pass<"gluon-resolve-auto-encodings", "mlir::
1111

1212
}
1313

14+
def GluonCanonicalize: Pass<"gluon-canonicalize"> {
15+
let summary = "reduced set of simplifications for TTGIR";
16+
17+
let description = [{
18+
The `gluon-canonicalize` pass applies a reduced set of simplification
19+
and canonicalization patterns to the module.
20+
}];
21+
let dependentDialects = [
22+
"mlir::arith::ArithDialect",
23+
"mlir::cf::ControlFlowDialect",
24+
"mlir::scf::SCFDialect",
25+
];
26+
}
27+
28+
def GluonInline: Pass<"gluon-inline"> {
29+
let summary = "reduced set of simplifications for TTGIR";
30+
31+
let description = [{
32+
The `gluon-inline` pass applies a reduced set of simplification
33+
and canonicalization patterns to the module.
34+
}];
35+
let dependentDialects = [];
36+
}
37+
1438
#endif

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,7 +1109,8 @@ w2 w2 w3 w3
11091109
"unsigned":$MDim,
11101110
"unsigned":$NDim,
11111111
"bool":$isTransposed,
1112-
"CTALayoutAttr":$CTALayout
1112+
"CTALayoutAttr":$CTALayout,
1113+
DefaultValuedParameter<"std::optional<Type>", "FloatType::get($_ctxt, 32)">:$elementType
11131114
);
11141115

11151116
let builders = [
@@ -1118,9 +1119,11 @@ w2 w2 w3 w3
11181119
"unsigned":$MDim,
11191120
"unsigned":$NDim,
11201121
"bool":$isTransposed,
1121-
"CTALayoutAttr":$CTALayout), [{
1122+
"CTALayoutAttr":$CTALayout,
1123+
"std::optional<Type>":$elementType), [{
11221124
SmallVector<unsigned> tilesPerWarp(warpsPerCTA.size(), 1);
1123-
return $_get(context, version, warpsPerCTA, tilesPerWarp, MDim, NDim, isTransposed, CTALayout);
1125+
1126+
return $_get(context, version, warpsPerCTA, tilesPerWarp, MDim, NDim, isTransposed, CTALayout, elementType);
11241127
}]>
11251128
];
11261129

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -411,12 +411,6 @@ def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc"> {
411411
);
412412
let results = (outs Arg<TT_Ptr, "", [MemAlloc<GlobalMemory>]>:$result);
413413

414-
let builders = [
415-
OpBuilder<(ins "Type":$result, "int32_t":$nbytes, "int32_t":$alignment),
416-
[{ build($_builder, $_state, result,
417-
$_builder.getI32IntegerAttr(nbytes), $_builder.getI32IntegerAttr(alignment)); }]>
418-
];
419-
420414
let assemblyFormat = [{attr-dict `:` qualified(type($result))}];
421415
}
422416

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -360,18 +360,4 @@ def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::Mod
360360
"mlir::triton::TritonDialect"];
361361
}
362362

363-
def TritonGPUCanonicalize: Pass<"tritongpu-canonicalize"> {
364-
let summary = "reduced set of simplifications for TTGIR";
365-
366-
let description = [{
367-
The `tritongpu-canonicalize` pass applies a reduced set of simplification
368-
and canonicalization patterns to the module.
369-
}];
370-
let dependentDialects = [
371-
"mlir::arith::ArithDialect",
372-
"mlir::cf::ControlFlowDialect",
373-
"mlir::scf::SCFDialect",
374-
];
375-
}
376-
377363
#endif

include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td

Lines changed: 114 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,32 @@ include "triton/Dialect/Triton/IR/TritonTypes.td"
77
include "mlir/IR/OpBase.td"
88
include "mlir/Interfaces/SideEffectInterfaces.td"
99

10+
//
11+
// Interfaces
12+
//
13+
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
14+
15+
//
16+
// Ops
17+
//
18+
1019
class TTI_Op<string mnemonic, list<Trait> traits = []> :
1120
Op<TritonInstrument_Dialect, mnemonic, traits> {
1221
}
1322

14-
// Define an array of pointers to shared memory buffers
23+
def TTI_ExperimentalAssertInThreadOp : TTI_Op<"experimental_assert_in_thread", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
24+
let summary = "assert the condition within the current thread";
25+
let description = [{
26+
Assert that the condition is true given all the values are available in the current thread.
27+
If the condition is false, the message is printed, and the program is aborted.
28+
If check_any is true, any of the values in the condition must be true. Otherwise, all the
29+
values in the condition must be true.
30+
}];
31+
let arguments = (ins I1Tensor:$condition, StrAttr:$message, BoolAttr:$check_any);
32+
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
33+
}
34+
35+
1536
def TTI_ExperimentalSharedBufferPointersOp : TTI_Op<"experimental_shared_buffer_pointers", [Pure]> {
1637
let summary = "definte an array of pointers to shared memory buffers";
1738
let description = [{
@@ -24,59 +45,119 @@ def TTI_ExperimentalSharedBufferPointersOp : TTI_Op<"experimental_shared_buffer_
2445
}];
2546
}
2647

27-
// Check if writing to a buffer guarded by a mbar is valid
28-
def TTI_ExperimentalCheckAsyncWriteWithMbarSharedOp : TTI_Op<"experimental_check_async_write_with_mbar_shared", [Pure]> {
29-
let summary = "check if writing to a buffer guarded by a mbar is valid";
48+
49+
def TTI_ExperimentalCheckOutstandingWritesOp : TTI_Op<"experimental_check_outstanding_writes", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
50+
let summary = "check if there are outstanding writes to a buffer guarded by a mbar";
3051
let description = [{
31-
Check if writing to a shared memory buffer guarded by a mbar is valid.
32-
Update the buffer state and assert if the buffer is being read or written.
52+
Check if there are outstanding writes to a buffer guarded by a mbar.
3353
}];
3454
let arguments = (ins
35-
TTG_MemDescType:$buffer,
36-
TTG_MemDescType:$mbar,
55+
TTG_MemDescType:$buf,
3756
TT_Tensor:$buffers,
38-
TT_Tensor:$states,
39-
TT_Tensor:$barriers
57+
TT_PtrLike:$writeBars,
58+
TypeAttr:$writeBarsType,
59+
Optional<I1>:$pred
4060
);
41-
let results = (outs
42-
TT_Tensor:$outStates,
43-
TT_Tensor:$outBarriers
61+
let assemblyFormat = [{
62+
$buf `{` $buffers `,` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($writeBars)
63+
}];
64+
let hasVerifier = 1;
65+
}
66+
67+
68+
def TTI_ExperimentalCheckOutstandingReadsOp : TTI_Op<"experimental_check_outstanding_reads", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
69+
let summary = "check if there are outstanding reads from a buffer guarded by a mbar";
70+
let description = [{
71+
Check if there are outstanding reads from a buffer guarded by a mbar.
72+
}];
73+
let arguments = (ins
74+
TTG_MemDescType:$buf,
75+
TT_Tensor:$buffers,
76+
TT_PtrLike:$readBars,
77+
TypeAttr:$readBarsType,
78+
Optional<I1>:$pred
4479
);
4580
let assemblyFormat = [{
46-
$buffer `,` $mbar `{` $buffers `,` $states `,` $barriers `}` attr-dict `:` type($buffer) `,` type($mbar) `,` type($buffers) `,` type($states) `,` type($barriers) `->` type($outStates) `,` type($outBarriers)
81+
$buf `{` $buffers `,` $readBars `(` $readBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($readBars)
4782
}];
48-
let builders = [
49-
OpBuilder<(ins "Value":$buffer, "Value":$mbar, "Value":$buffers, "Value":$states, "Value":$barriers),[{
50-
build($_builder, $_state, {states.getType(), barriers.getType()}, buffer, mbar, buffers, states, barriers);
51-
}]>
52-
];
83+
let hasVerifier = 1;
5384
}
5485

55-
def TTI_ExperimentalCheckWaitMbarOp : TTI_Op<"experimental_check_wait_mbar", [Pure]> {
56-
let summary = "check if waiting on a mbar is valid and update the barrier state";
86+
87+
def TTI_ExperimentalMarkAsWriteOp : TTI_Op<"experimental_mark_as_write", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
88+
let summary = "mark a buffer as being written to using mbar as a guard";
5789
let description = [{
58-
Check if waiting on a mbar is valid and update the barrier state.
90+
Mark a buffer as being written to using mbar as a guard.
5991
}];
6092
let arguments = (ins
93+
TTG_MemDescType:$buf,
6194
TTG_MemDescType:$mbar,
95+
TT_Tensor:$buffers,
96+
TT_PtrLike:$writeBars,
97+
TypeAttr:$writeBarsType,
98+
Optional<I1>:$pred
99+
);
100+
let assemblyFormat = [{
101+
$buf `,` $mbar `{` $buffers `,` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($mbar) `,` type($buffers) `,` type($writeBars)
102+
}];
103+
let hasVerifier = 1;
104+
}
105+
106+
107+
def TTI_ExperimentalMarkAsReadOp : TTI_Op<"experimental_mark_as_read", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
108+
let summary = "mark a buffer as being read from using mbar as a guard";
109+
let description = [{
110+
Mark a buffer as being read from using mbar as a guard.
111+
}];
112+
let arguments = (ins
113+
TTG_MemDescType:$buf,
114+
TTG_MemDescType:$mbar,
115+
TT_Tensor:$buffers,
62116
TT_Tensor:$barriers,
63-
TT_Tensor:$states
64-
);
117+
TT_PtrLike:$readBars,
118+
TypeAttr:$readBarsType,
119+
Optional<I1>:$pred
120+
);
121+
let assemblyFormat = [{
122+
$buf `,` $mbar `{` $buffers `,` $barriers `,` $readBars `(` $readBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($mbar) `,` type($buffers) `,` type($barriers) `,` type($readBars)
123+
}];
124+
let hasVerifier = 1;
125+
}
65126

66-
let results = (outs
67-
TT_Tensor:$outStates,
68-
TT_Tensor:$outBarriers);
69127

128+
def TTI_ExperimentalClearWriteBarrierOp : TTI_Op<"experimental_clear_write_barrier", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
129+
let summary = "clear the write state for buffers being guarded by an mbar";
130+
let description = [{
131+
Clear the write state for buffers being guarded by an mbar.
132+
}];
133+
let arguments = (ins
134+
TTG_MemDescType:$mbar,
135+
TT_PtrLike:$writeBars,
136+
TypeAttr:$writeBarsType,
137+
Optional<I1>:$pred
138+
);
70139
let assemblyFormat = [{
71-
$mbar `{` $states `,` $barriers `}` attr-dict `:` type($mbar) `,` type($states) `,` type($barriers) `->` type($outStates) `,` type($outBarriers)
140+
$mbar `{` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($writeBars)
72141
}];
142+
}
73143

74-
let builders = [
75-
OpBuilder<(ins "Value":$mbar, "Value":$barriers, "Value":$states),
76-
[{
77-
build($_builder, $_state, {states.getType(), barriers.getType()}, mbar, barriers, states);
78-
}]>];
79144

145+
def TTI_ExperimentalClearReadBarrierOp : TTI_Op<"experimental_clear_read_barrier", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
146+
let summary = "clear the read state for buffers being guarded by an mbar";
147+
let description = [{
148+
Clear the read state for buffers being guarded by an mbar.
149+
}];
150+
let arguments = (ins
151+
TTG_MemDescType:$mbar,
152+
TT_Tensor:$barriers,
153+
TT_PtrLike:$readBars,
154+
TypeAttr:$readBarsType,
155+
Optional<I1>:$pred
156+
);
157+
let assemblyFormat = [{
158+
$mbar `{` $barriers `,` $readBars `(` $readBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($barriers) `,` type($readBars)
159+
}];
160+
let hasVerifier = 1;
80161
}
81162

82163
#endif // TRITONINSTRUMENT_OPS
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
2+
3+
namespace mlir::triton::instrument {
4+
5+
Operation *createStoreScratchMemory(OpBuilder &b, Location loc, Value alloc,
6+
Value tensor, RankedTensorType tensorType);
7+
Operation *createLoadScratchMemory(OpBuilder &b, Location loc, Value alloc,
8+
RankedTensorType tensorType);
9+
Value expandOuterSlicedDim(OpBuilder &b, Location loc, Value tensor);
10+
TypedValue<RankedTensorType> createConstIntTensor(OpBuilder &builder,
11+
Location loc, int val,
12+
RankedTensorType tensorType);
13+
14+
} // namespace mlir::triton::instrument

0 commit comments

Comments
 (0)