Commit d8bde0a

Merge OpenAI Triton commit 6c3d943 (#4583)

This PR changes the Triton base from 40f7163 to 6c3d943 (Jun 20). Pass rate: 97.12%

2 parents d4400d2 + b6fe9be; commit d8bde0a

65 files changed: +2237, -447 lines

bin/RegisterTritonDialects.h (7 additions, 3 deletions)

```diff
@@ -17,6 +17,7 @@
 #include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

 // Below headers will allow registration to ROCm passes
@@ -26,6 +27,7 @@

 #include "triton/Dialect/Triton/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonInstrument/Transforms/Passes.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

 #include "nvidia/hopper/include/Transforms/Passes.h"
@@ -64,6 +66,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::triton::gpu::registerTritonGPUPasses();
   mlir::triton::nvidia_gpu::registerTritonNvidiaGPUPasses();
   mlir::test::intel::registerTestAxisInfoPass();
+  mlir::triton::instrument::registerTritonInstrumentPasses();
   mlir::test::registerTestAliasPass();
   mlir::test::registerTestAlignmentPass();
   mlir::test::registerAMDTestAlignmentPass();
@@ -123,9 +126,10 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   registry.insert<
       mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
       mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
-      mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
-      mlir::arith::ArithDialect, mlir::scf::SCFDialect, mlir::gpu::GPUDialect,
-      mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
+      mlir::triton::gpu::TritonGPUDialect,
+      mlir::triton::instrument::TritonInstrumentDialect,
+      mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect,
+      mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
       mlir::triton::nvgpu::NVGPUDialect, mlir::triton::nvws::NVWSDialect,
       mlir::triton::amdgpu::TritonAMDGPUDialect,
       mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect,
```
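Registration here is the only wiring the new dialect needs on the tooling side: once inserted into the `DialectRegistry`, `tti` ops can be parsed and the instrumentation passes invoked by name. A minimal sketch of an opt-style driver built on this header (hypothetical `main`; the actual driver in `bin/` may differ):

```cpp
// Sketch of a triton-opt-style driver, assuming the header above.
#include "RegisterTritonDialects.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  // Registers all passes and inserts all dialects, now including
  // mlir::triton::instrument::TritonInstrumentDialect.
  registerTritonDialects(registry);
  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "Triton opt driver\n", registry));
}
```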

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h (5 additions)

```diff
@@ -102,6 +102,11 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                   const TargetInfoBase &targetInfo,
                                   PatternBenefit benefit);

+void populateInstrumentationToLLVMPatterns(LLVMTypeConverter &typeConverter,
+                                           const TargetInfoBase &targetInfo,
+                                           RewritePatternSet &patterns,
+                                           PatternBenefit benefit);
+
 } // namespace triton
 } // namespace mlir
```
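The new declaration follows MLIR's usual populate-patterns convention: it only adds rewrite patterns to a set, and the surrounding conversion pass decides when to apply them. A hedged sketch of a call site (hypothetical function name; the real caller is the TritonGPU-to-LLVM conversion pass):

```cpp
// Hypothetical call site, assuming only the declarations in this header.
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"

using namespace mlir;

void addInstrumentationPatterns(LLVMTypeConverter &typeConverter,
                                const triton::TargetInfoBase &targetInfo,
                                RewritePatternSet &patterns) {
  // Benefit orders patterns when several match the same op; 1 is the default.
  triton::populateInstrumentationToLLVMPatterns(typeConverter, targetInfo,
                                                patterns, /*benefit=*/1);
}
```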

include/triton/Conversion/TritonGPUToLLVM/Utility.h (9 additions, 6 deletions)

```diff
@@ -537,12 +537,6 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
     const TargetInfoBase &target,
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

-[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
-    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
-    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
-    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
-
 [[nodiscard]] bool emitTransferBetweenRegistersAndShared(
     LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
     std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
@@ -574,6 +568,15 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
     ConversionPatternRewriter &rewriter,
     const TargetInfoBase &targetInfo);

+// Lower local_load/local_store via ld.shared/st.shared
+SmallVector<Value> lowerLocalLdSt(Location loc, MLIRContext *ctx,
+                                  // Map from registers to offset
+                                  LinearLayout cvt, ArrayRef<Value> valsArray,
+                                  // Input for store, output for load
+                                  Type llvmElemTy, Value smemBase,
+                                  ConversionPatternRewriter &rewriter,
+                                  const TargetInfoBase &targetInfo);
+
 SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                     RewriterBase &rewriter);
```
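Two things to note: the diff drops one of two adjacent `emitTransferBetweenRegistersAndShared` declarations (identical at least in the lines shown), and the new `lowerLocalLdSt` helper threads a register-to-offset `LinearLayout` through the lowering, with `valsArray` acting, per the inline comments, as input for stores and output for loads. A hedged usage sketch (hypothetical call site; namespaces assumed from the surrounding headers):

```cpp
// Hypothetical call site, assuming the declarations in Utility.h.
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

using namespace mlir;

void copyThroughShared(Location loc, MLIRContext *ctx,
                       triton::LinearLayout cvt, ArrayRef<Value> regVals,
                       Type llvmElemTy, Value smemBase,
                       ConversionPatternRewriter &rewriter,
                       const triton::TargetInfoBase &targetInfo) {
  // Store: valsArray carries the register values to write to shared memory.
  triton::lowerLocalLdSt(loc, ctx, cvt, regVals, llvmElemTy, smemBase,
                         rewriter, targetInfo);
  // Load: valsArray is empty; the loaded values come back in the result.
  SmallVector<Value> loaded = triton::lowerLocalLdSt(
      loc, ctx, cvt, /*valsArray=*/{}, llvmElemTy, smemBase, rewriter,
      targetInfo);
  (void)loaded;
}
```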

include/triton/Dialect/CMakeLists.txt (1 addition; file name not shown in the capture, inferred from its contents)

```diff
@@ -1,3 +1,4 @@
 add_subdirectory(Triton)
 add_subdirectory(TritonGPU)
 add_subdirectory(TritonNvidiaGPU)
+add_subdirectory(TritonInstrument)
```

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td (91 additions, 8 deletions)

```diff
@@ -167,21 +167,22 @@ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
   ];
 }

-def SwizzledSharedEncodingAttr :
-    TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
+def SwizzledSharedEncodingAttr
+    : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding",
+                     [SharedEncodingTrait, LayoutEncodingTrait]> {
   let mnemonic = "swizzled_shared";

   let description = [{
 An encoding for tensors whose elements may be simultaneously accessed by
-different cuda threads in the programs, via shared memory. In other words,
+different GPU threads in the programs, via shared memory. In other words,
 for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.

 In order to avoid shared memory bank conflicts, elements may be swizzled.
 Here are some examples. In all cases, the input tensor is [0, 1, ..., n-1].

 1. Basic swizzling

-#shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}>
+#ttg.swizzled_shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}>
 [ 0,  1,  2,  3],  // xor with 0
 [ 5,  4,  7,  6],  // xor with 1
 [10, 11,  8,  9],  // xor with 2
@@ -192,7 +193,7 @@ out[r][c^r]).

 2. Multiple rows per phase

-#shared<{vec=1, perPhase=2, maxPhase=4, order=[1,0]}>
+#ttg.swizzled_shared<{vec=1, perPhase=2, maxPhase=4, order=[1,0]}>
 [ 0,  1,  2,  3],  // phase 0 (xor with 0)
 [ 4,  5,  6,  7],
 [ 9,  8, 11, 10],  // phase 1 (xor with 1)
@@ -203,7 +204,7 @@ means that pairs of 2 rows get the same swizzling.

 3. Max-phase applied

-$shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}>
+#ttg.swizzled_shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}>
 [ 0,  1,  2,  3],  // phase 0 (xor with 0)
 [ 5,  4,  7,  6],  // phase 1 (xor with 1)
 [ 8,  9, 10, 11],  // phase 0
@@ -218,7 +219,7 @@ effect of limiting the maximum value of the xor to m-1.

 4. Max-phase and per-phase

-#shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}>
+#ttg.swizzled_shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}>
 [ 0,  1,  2,  3],  // phase 0 (xor with 0)
 [ 4,  5,  6,  7],  // phase 0
 [ 9,  8, 11, 10],  // phase 1 (xor with 1)
@@ -234,7 +235,7 @@ maximum value of maxPhase-1. In other words, elements of row r are xor'ed with

 5. Adding vec

-#shared<{vec=2, perPhase=1, maxPhase=4, order=[1,0]}>
+#ttg.swizzled_shared<{vec=2, perPhase=1, maxPhase=4, order=[1,0]}>
 [ 0,  1,  2,  3,  4,  5,  6,  7],
 [10, 11,  8,  9, 14, 15, 12, 13],
 [20, 21, 22, 23, 16, 17, 18, 19],
@@ -383,6 +384,88 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
   let genVerifyDecl = 1;
 }

+def PaddedSharedEncodingAttr
+    : TritonGPU_Attr<"PaddedSharedEncoding", "padded_shared_encoding",
+                     [SharedEncodingTrait, LayoutEncodingTrait]> {
+  let mnemonic = "padded_shared";
+
+  let description = [{
+An encoding for tensors whose elements may be simultaneously accessed by
+different GPU threads in the programs, via shared memory. In other words,
+for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
+Compared to SwizzledSharedEncodingAttr, this encoding uses padding to avoid
+shared memory bank conflicts.
+
+Formally, given a layout:
+  padded_shared<[<interval_0>:+<pad_0>, <interval_1>:+<pad_1>, ...]>
+We insert a padding of `<pad_i>` elements after every `<interval_i>` elements.
+Multiple interval-padding pairs are supported for flexibility in multi-tiered
+padding schemes; they compose in an additive manner. So for a 1-D tensor element
+at index i, the corresponding shared memory location index is
+  i + \sum_{k} (i / interval_k) * pad_k
+where / denotes integer division. `<interval_i>` and `<pad_i>` must all be
+powers of two.
+
+Some concrete examples, using `eM` to mean tensor elements and `pN` to mean
+padding:
+
+1. Single interval-padding pair:
+
+  #ttg.padded_shared<[2:+2]>
+  [e0, e1, p0, p1,
+   e2, e3, p2, p3,
+   ...]
+
+2. Double interval-padding pairs:
+
+  #ttg.padded_shared<[2:+1, 4:+2]>
+  [e0, e1, p0,
+   e2, e3, p1, p2, p3,
+   e4, e5, p4,
+   e6, e7, p5, p6, p7,
+   ...]
+
+In addition to interval-padding pairs, this encoding requires an `order` to
+specify the logical tensor dimensions from fastest- to slowest-varying.
+It may optionally support CGA-level organization like other encoding
+attributes, for example:
+  #ttg.padded_shared<[2:+1, 4:+2] {
+    order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1],
+    CTAOrder = [0, 1]}>
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"unsigned">:$intervals,
+    ArrayRefParameter<"unsigned">:$paddings,
+    // Order of logical tensor dimensions; fastest-varying first.
+    ArrayRefParameter<"unsigned">:$order,
+    "CTALayoutAttr":$CTALayout
+  );
+
+  let builders = [
+    AttrBuilder<(ins "ArrayRef<std::pair<unsigned, unsigned>>":$intervalPads,
+                     "ArrayRef<unsigned>":$order, "CTALayoutAttr":$ctaLayout)>,
+  ];
+
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    unsigned getRank() const { return getOrder().size(); }
+    int32_t getAlignment() const { return 16; }
+
+    unsigned getMinInterval() const {
+      return *llvm::min_element(getIntervals());
+    }
+
+    // Returns the total number of elements including padding given the input
+    // tensor shape.
+    int64_t getPaddedSize(ArrayRef<int64_t> shape) const;
+
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTAOrder() const;
+    SmallVector<unsigned> getCTASplitNum() const;
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let genVerifyDecl = 1;
+}
+
 def NVMMASharedEncodingAttr :
     TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
   let mnemonic = "nvmma_shared";
```
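The additive padding formula is easy to sanity-check against the second example in the description: each completed interval contributes its pad to the offset, and independent interval-padding pairs simply sum. A small self-contained sketch (hypothetical helper, not part of the commit):

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Shared-memory index for tensor element i under a padded layout, per the
// formula above; `/` is integer division, and since intervals and pads are
// powers of two, the divisions and multiplies reduce to shifts.
int64_t paddedOffset(
    int64_t i,
    const std::vector<std::pair<unsigned, unsigned>> &intervalPads) {
  int64_t offset = i;
  for (auto [interval, pad] : intervalPads)
    offset += (i / interval) * pad; // one pad block per completed interval
  return offset;
}

// For #ttg.padded_shared<[2:+1, 4:+2]>: paddedOffset(4, {{2, 1}, {4, 2}})
// == 4 + 2*1 + 1*2 == 8, matching e4's position in the example above.
```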
include/triton/Dialect/TritonInstrument/CMakeLists.txt (new file, 2 additions; file name not shown in the capture, inferred from its contents)

```diff
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+add_subdirectory(Transforms)
```
include/triton/Dialect/TritonInstrument/IR/CMakeLists.txt (new file, 13 additions; file name not shown in the capture, inferred from its contents)

```diff
@@ -0,0 +1,13 @@
+set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})
+
+set(LLVM_TARGET_DEFINITIONS TritonInstrumentDialect.td)
+mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=tti)
+mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=tti)
+add_mlir_doc(TritonInstrumentDialect TritonInstrumentDialect dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS TritonInstrumentOps.td)
+mlir_tablegen(Ops.h.inc -gen-op-decls)
+mlir_tablegen(Ops.cpp.inc -gen-op-defs)
+add_mlir_doc(TritonInstrumentOps TritonInstrumentOps dialects/ -gen-op-doc)
+
+add_public_tablegen_target(TritonInstrumentTableGen)
```
include/triton/Dialect/TritonInstrument/IR/Dialect.h (new file, 12 additions; file name not shown in the capture, inferred from the include guard)

```diff
@@ -0,0 +1,12 @@
+#ifndef TRITON_DIALECT_TRITONINSTRUMENT_IR_DIALECT_H_
+#define TRITON_DIALECT_TRITONINSTRUMENT_IR_DIALECT_H_
+
+// TritonInstrument depends on Triton and TritonGPU
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+
+#define GET_OP_CLASSES
+#include "triton/Dialect/TritonInstrument/IR/Dialect.h.inc"
+#include "triton/Dialect/TritonInstrument/IR/Ops.h.inc"
+
+#endif // TRITON_DIALECT_TRITONINSTRUMENT_IR_DIALECT_H_
```
include/triton/Dialect/TritonInstrument/IR/TritonInstrumentDialect.td (new file, 11 additions; file name not shown in the capture, inferred from the CMake rules above)

```diff
@@ -0,0 +1,11 @@
+#ifndef TRITONINSTRUMENT_DIALECT
+#define TRITONINSTRUMENT_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+def TritonInstrument_Dialect : Dialect {
+  let name = "tti";
+  let cppNamespace = "::mlir::triton::instrument";
+}
+
+#endif // TRITONINSTRUMENT_DIALECT
```
include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td (new file, 82 additions; file name not shown in the capture, inferred from the CMake rules above)

```diff
@@ -0,0 +1,82 @@
+#ifndef TRITONINSTRUMENT_OPS
+#define TRITONINSTRUMENT_OPS
+
+include "triton/Dialect/TritonInstrument/IR/TritonInstrumentDialect.td"
+include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
+include "triton/Dialect/Triton/IR/TritonTypes.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+class TTI_Op<string mnemonic, list<Trait> traits = []> :
+    Op<TritonInstrument_Dialect, mnemonic, traits> {
+}
+
+// Define an array of pointers to shared memory buffers
+def TTI_ExperimentalSharedBufferPointersOp : TTI_Op<"experimental_shared_buffer_pointers", [Pure]> {
+  let summary = "define an array of pointers to shared memory buffers";
+  let description = [{
+    Create a tensor of pointers to shared memory buffers.
+  }];
+  let arguments = (ins DenseI32ArrayAttr:$offsets);
+  let results = (outs TT_Tensor:$result);
+  let assemblyFormat = [{
+    attr-dict `:` type($result)
+  }];
+}
+
+// Check if writing to a buffer guarded by a mbar is valid
+def TTI_ExperimentalCheckAsyncWriteWithMbarSharedOp : TTI_Op<"experimental_check_async_write_with_mbar_shared", [Pure]> {
+  let summary = "check if writing to a buffer guarded by a mbar is valid";
+  let description = [{
+    Check if writing to a shared memory buffer guarded by a mbar is valid.
+    Update the buffer state and assert if the buffer is being read or written.
+  }];
+  let arguments = (ins
+    TTG_MemDescType:$buffer,
+    TTG_MemDescType:$mbar,
+    TT_Tensor:$buffers,
+    TT_Tensor:$states,
+    TT_Tensor:$barriers
+  );
+  let results = (outs
+    TT_Tensor:$outStates,
+    TT_Tensor:$outBarriers
+  );
+  let assemblyFormat = [{
+    $buffer `,` $mbar `{` $buffers `,` $states `,` $barriers `}` attr-dict `:` type($buffer) `,` type($mbar) `,` type($buffers) `,` type($states) `,` type($barriers) `->` type($outStates) `,` type($outBarriers)
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$buffer, "Value":$mbar, "Value":$buffers, "Value":$states, "Value":$barriers), [{
+      build($_builder, $_state, {states.getType(), barriers.getType()}, buffer, mbar, buffers, states, barriers);
+    }]>
+  ];
+}
+
+def TTI_ExperimentalCheckWaitMbarOp : TTI_Op<"experimental_check_wait_mbar", [Pure]> {
+  let summary = "check if waiting on a mbar is valid and update the barrier state";
+  let description = [{
+    Check if waiting on a mbar is valid and update the barrier state.
+  }];
+  let arguments = (ins
+    TTG_MemDescType:$mbar,
+    TT_Tensor:$barriers,
+    TT_Tensor:$states
+  );
+
+  let results = (outs
+    TT_Tensor:$outStates,
+    TT_Tensor:$outBarriers);
+
+  let assemblyFormat = [{
+    $mbar `{` $states `,` $barriers `}` attr-dict `:` type($mbar) `,` type($states) `,` type($barriers) `->` type($outStates) `,` type($outBarriers)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$mbar, "Value":$barriers, "Value":$states),
+    [{
+      build($_builder, $_state, {states.getType(), barriers.getType()}, mbar, barriers, states);
+    }]>];
+
+}
+
+#endif // TRITONINSTRUMENT_OPS
```
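Because both check ops declare custom builders that take only the SSA operands and infer result types from the incoming state/barrier tensors, creating them from an instrumentation pass is a single `create` call. A hedged sketch (hypothetical helper; accessor names follow MLIR's generated `get<Name>()` convention):

```cpp
// Hypothetical helper inside an instrumentation pass.
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton::instrument;

void checkAsyncWrite(OpBuilder &b, Location loc, Value buffer, Value mbar,
                     Value bufferPtrs, Value &states, Value &barriers) {
  // Result types are inferred from states/barriers by the custom builder.
  auto check = b.create<ExperimentalCheckAsyncWriteWithMbarSharedOp>(
      loc, buffer, mbar, bufferPtrs, states, barriers);
  // Thread the updated tracking tensors into the next check.
  states = check.getOutStates();
  barriers = check.getOutBarriers();
}
```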
