Skip to content

Commit 039e84a

Browse files
Merge commit 'eb6654624b4acb937795b00123dbd48da4738d0b'
2 parents bc82b95 + eb66546 commit 039e84a

40 files changed

+1418
-512
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,6 @@ def TTG_MemDescIndexOp : TTG_Op<"memdesc_index", [Pure, MemDescViewTrait]> {
214214
- the output shape is 4x16xf16, and
215215
- index = 1.
216216
Then the output descriptor is equivalent to input[1], where input is the logical tensor.
217-
218-
When the input is of rank 1 (i.e, shape=[k]), the output will have shape=[1].
219217
}];
220218

221219
let arguments = (ins TTG_MemDescType:$src, I32:$index);

include/triton/Dialect/TritonInstrument/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ add_mlir_doc(TritonInstrumentDialect TritonInstrumentDialect dialects/ -gen-dial
88
set(LLVM_TARGET_DEFINITIONS TritonInstrumentOps.td)
99
mlir_tablegen(Ops.h.inc -gen-op-decls)
1010
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
11+
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
12+
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
1113
add_mlir_doc(TritonInstrumentOps TritonInstrumentOps dialects/ -gen-op-doc)
1214

1315
add_public_tablegen_target(TritonInstrumentTableGen)

include/triton/Dialect/TritonInstrument/IR/Dialect.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "triton/Dialect/Triton/IR/Dialect.h"
66
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
77

8+
#include "triton/Dialect/TritonInstrument/IR/OpsEnums.h.inc"
9+
810
#define GET_OP_CLASSES
911
#include "triton/Dialect/TritonInstrument/IR/Dialect.h.inc"
1012
#include "triton/Dialect/TritonInstrument/IR/Ops.h.inc"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef TRITONINSTRUMENT_ATTR_DEFS
2+
#define TRITONINSTRUMENT_ATTR_DEFS
3+
4+
include "mlir/IR/EnumAttr.td"
5+
6+
def TT_MemTypeAttr : I32EnumAttr<
7+
"MemType", "",
8+
[
9+
I32EnumAttrCase<"SHARED", 0, "shared">,
10+
I32EnumAttrCase<"TENSOR", 1, "tensor">,
11+
]> {
12+
let cppNamespace = "::mlir::triton::instrument";
13+
}
14+
15+
#endif // TRITONINSTRUMENT_ATTR_DEFS

include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
66
include "triton/Dialect/Triton/IR/TritonTypes.td"
77
include "mlir/IR/OpBase.td"
88
include "mlir/Interfaces/SideEffectInterfaces.td"
9+
include "triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td"
910

1011
//
1112
// Interfaces
@@ -33,33 +34,43 @@ def TTI_ExperimentalAssertInThreadOp : TTI_Op<"experimental_assert_in_thread", [
3334
}
3435

3536

36-
def TTI_ExperimentalSharedBufferPointersOp : TTI_Op<"experimental_shared_buffer_pointers", [Pure]> {
37+
def TTI_ExperimentalBufferPointersOp : TTI_Op<"experimental_buffer_pointers", [Pure]> {
3738
let summary = "define an array of pointers to shared memory buffers";
3839
let description = [{
3940
Create a tensor of pointers to shared memory buffers.
4041
}];
41-
let arguments = (ins DenseI32ArrayAttr:$offsets);
42+
let arguments = (ins DenseI32ArrayAttr:$offsets, TT_MemTypeAttr:$memType);
4243
let results = (outs TT_Tensor:$result);
4344
let assemblyFormat = [{
44-
attr-dict `:` type($result)
45+
$offsets `,` $memType attr-dict `:` type($result)
4546
}];
4647
}
4748

4849

4950
def TTI_ExperimentalCheckOutstandingWritesOp : TTI_Op<"experimental_check_outstanding_writes", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
5051
let summary = "check if there are outstanding writes to a buffer guarded by a mbar";
5152
let description = [{
52-
Check if there are outstanding writes to a buffer guarded by a mbar.
53+
Check if the writeState tensor has non-zero value associated with the buffer.
54+
55+
`writeState` is a tensor of 8b bitfields, where:
56+
- bit 0: 1 if the buffer is being written to
57+
- bit 1: 1 if the write is *not* hwPipelined
58+
59+
If hwPipelined is true, shift the bitfield by 1 to check the second bit - this
60+
means that the error won't be triggered if another pipelined write is outstanding.
5361
}];
5462
let arguments = (ins
5563
TTG_MemDescType:$buf,
5664
TT_Tensor:$buffers,
5765
TT_PtrLike:$writeBars,
5866
TypeAttr:$writeBarsType,
67+
TT_PtrLike:$writeState,
68+
TypeAttr:$writeStateType,
69+
I1Attr:$hwPipelined,
5970
Optional<I1>:$pred
6071
);
6172
let assemblyFormat = [{
62-
$buf `{` $buffers `,` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($writeBars)
73+
$buf `{` $buffers `,` $writeBars `(` $writeBarsType `)` `,` $writeState `(` $writeStateType `)` `}` (`,` $pred^)? `pipelined` $hwPipelined attr-dict `:` type($buf) `,` type($buffers) `,` type($writeBars) `,` type($writeState)
6374
}];
6475
let hasVerifier = 1;
6576
}
@@ -87,18 +98,49 @@ def TTI_ExperimentalCheckOutstandingReadsOp : TTI_Op<"experimental_check_outstan
8798
def TTI_ExperimentalMarkAsWriteOp : TTI_Op<"experimental_mark_as_write", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
8899
let summary = "mark a buffer as being written to using mbar as a guard";
89100
let description = [{
90-
Mark a buffer as being written to using mbar as a guard.
101+
Mark a buffer as being written to. It is not yet tracked by a barrier, until
102+
`commit_write_with_barrier` is called, at which point all the buffers being written
103+
to are marked as tracked by the barrier.
104+
105+
`writeState` is a tensor of 8b bitfields, where:
106+
- bit 0: 1 if the buffer is being written to
107+
- bit 1: 1 if the write is *not* hwPipelined
108+
109+
If hwPipelined is true, the write won't trigger an error if another pipelined
110+
write is executed later without waiting for the barrier.
91111
}];
92112
let arguments = (ins
93113
TTG_MemDescType:$buf,
94-
TTG_MemDescType:$mbar,
95114
TT_Tensor:$buffers,
115+
TT_PtrLike:$writeState,
116+
TypeAttr:$writeStateType,
117+
I1Attr:$hwPipelined,
118+
Optional<I1>:$pred
119+
);
120+
let assemblyFormat = [{
121+
$buf `{` $buffers `,` $writeState `(` $writeStateType `)` `}` (`,` $pred^)? `pipelined` $hwPipelined attr-dict `:` type($buf) `,` type($buffers) `,` type($writeState)
122+
}];
123+
let hasVerifier = 1;
124+
}
125+
126+
127+
def TTI_ExperimentalCommitWriteWithBarrierOp : TTI_Op<"experimental_commit_write_with_barrier", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
128+
let summary = "Mark all buffers being currently written as tracked by the barrier.";
129+
let description = [{
130+
For all buffers currently marked in writeState tensor, mark them as tracked by the mbar in
131+
writeBars tensor.
132+
}];
133+
let arguments = (ins
134+
TTG_MemDescType:$mbar,
135+
TT_Tensor:$barriers,
96136
TT_PtrLike:$writeBars,
97137
TypeAttr:$writeBarsType,
138+
TT_PtrLike:$writeState,
139+
TypeAttr:$writeStateType,
98140
Optional<I1>:$pred
99141
);
100142
let assemblyFormat = [{
101-
$buf `,` $mbar `{` $buffers `,` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($mbar) `,` type($buffers) `,` type($writeBars)
143+
$mbar `{` $barriers `,` $writeBars `(` $writeBarsType `)` `,` $writeState `(` $writeStateType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($barriers) `,` type($writeBars) `,` type($writeState)
102144
}];
103145
let hasVerifier = 1;
104146
}
@@ -132,13 +174,17 @@ def TTI_ExperimentalClearWriteBarrierOp : TTI_Op<"experimental_clear_write_barri
132174
}];
133175
let arguments = (ins
134176
TTG_MemDescType:$mbar,
177+
TT_Tensor:$barriers,
135178
TT_PtrLike:$writeBars,
136179
TypeAttr:$writeBarsType,
180+
TT_PtrLike:$writeState,
181+
TypeAttr:$writeStateType,
137182
Optional<I1>:$pred
138183
);
139184
let assemblyFormat = [{
140-
$mbar `{` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($writeBars)
185+
$mbar `{` $barriers `,` $writeBars `(` $writeBarsType `)` `,` $writeState `(` $writeStateType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($barriers) `,` type($writeBars) `,` type($writeState)
141186
}];
187+
let hasVerifier = 1;
142188
}
143189

144190

@@ -160,6 +206,27 @@ def TTI_ExperimentalClearReadBarrierOp : TTI_Op<"experimental_clear_read_barrier
160206
let hasVerifier = 1;
161207
}
162208

209+
210+
def TTI_ExperimentalCheckBarrierWritesClearedOp : TTI_Op<"experimental_check_barrier_writes_cleared", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
211+
let summary = "verify that the barrier is not used to track any writes";
212+
let description = [{
213+
Verify that the barrier is not used to track any writes.
214+
}];
215+
let arguments = (ins
216+
TTG_MemDescType:$mbar,
217+
TT_Tensor:$barriers,
218+
TT_PtrLike:$writeBars,
219+
TypeAttr:$writeBarsType,
220+
Optional<I1>:$pred
221+
);
222+
let assemblyFormat = [{
223+
$mbar `{` $barriers `,` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($barriers) `,` type($writeBars)
224+
}];
225+
let hasVerifier = 1;
226+
}
227+
228+
229+
// TODO: Potentially resolve the naming/functionality clash with commit_write_with_barrier
163230
def TTI_ExperimentalStageWriteForCommitOp : TTI_Op<"experimental_stage_write_for_commit", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
164231
let summary = "Prepare for an async copy of a buffer. Staged until commit_group is called.";
165232
let description = [{

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -477,17 +477,21 @@ struct MemDescIndexOpConversion
477477
auto *ctx = op->getContext();
478478
auto b = TritonLLVMOpBuilder(loc, rewriter);
479479
auto srcTy = op.getSrc().getType();
480-
auto destTy = op.getResult().getType();
480+
auto dstTy = op.getResult().getType();
481481
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
482482

483+
// getAllocationShapePerCTA returns the correct number of fp4 elements that we
484+
// need to skip when we have fp4Padded=True. getShapePerCTA does not account
485+
// for this
486+
auto stride = product(
487+
getAllocationShapePerCTA(dstTy.getEncoding(), dstTy.getShape()));
488+
Value offset = b.mul(op.getIndex(), b.i32_val(stride));
483489
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
484490
llvmElemTy, rewriter);
485491
auto base = smemObj.getBase();
486492
auto elemPtrTy = base.getType();
487-
Value stride = smemObj.getStrides(srcTy, loc, rewriter).front();
488-
Value offset = b.mul(op.getIndex(), stride);
489493
auto prevOffsets = smemObj.getOffsets();
490-
SmallVector<Value> offsetVals(prevOffsets.end() - destTy.getRank(),
494+
SmallVector<Value> offsetVals(prevOffsets.end() - dstTy.getRank(),
491495
prevOffsets.end());
492496
// Advance the pointer and keep the opOffsets as the new shape
493497
smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),

lib/Conversion/TritonInstrumentToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ add_triton_library(TritonInstrumentToLLVM
77
TritonIR
88
TritonGPUIR
99
TritonInstrumentIR
10+
TritonNvidiaGPUIR
1011
)

0 commit comments

Comments
 (0)