
Commit 1cf06c5

[IR] Add typing for tensor descriptor types (#5147)
Currently tensor descriptors are just typed as `!tt.ptr<i8>`, which exposes the assumption that they are backed by a TMA descriptor. This changes them to a custom type, `!tt.tensordesc<tensor<...>>`, which is lowered to a pointer type in the LLVM IR. I also add two new IR ops used to cast between pointers and tensordesc objects:

```mlir
tt.reinterpret_tensor_descriptor %ptr : !tt.ptr<i8> to !tt.tensordesc<...>
triton_nvidia_gpu.tensor_desc_to_tma_ptr %desc : !tt.tensordesc<...> -> !tt.ptr<i8>
```

Really, both of these should be nvidia-specific, but the first is exposed in the Triton IR to keep support for the by-value TMA descriptor API around while we figure out whether it's possible to update to the new style.
1 parent 38c6284 commit 1cf06c5

File tree: 21 files changed, 263 additions & 105 deletions

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 40 additions & 24 deletions
```diff
@@ -956,9 +956,10 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
 //
 // Make Tensor Descriptor Op
 //
-def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor",
-    [Pure,
-     SameVariadicOperandSize]> {
+def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
+  Pure,
+  SameVariadicOperandSize,
+]> {
   let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size";

   let description = [{
@@ -969,23 +970,38 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor",
   let arguments = (ins
     TT_Ptr:$base,
     Variadic<I32>:$shape,
-    Variadic<I64>:$strides,
-    DenseI32ArrayAttr:$tensorShape
+    Variadic<I64>:$strides
   );

-  // TODO(peterbell10): define a custom IR type to represent descriptors
-  let results = (outs TT_Ptr:$result);
+  let results = (outs TT_TensorDescType:$result);

   let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";

   let builders = [
-    OpBuilder<(ins
-      "Value":$base,
-      "ValueRange":$shape,
-      "ValueRange":$strides,
-      "ArrayRef<int32_t>":$tensorShape
-    )>
+    OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape)>
   ];
+
+  let extraClassDeclaration = [{
+    ArrayRef<int64_t> getTensorShape() {
+      return getType().getBlockType().getShape();
+    }
+  }];
+}
+
+def ReinterpretTensorDescOp : TT_Op<"reinterpret_tensor_descriptor", [Pure]> {
+  let summary = "Reinterpret a pointer as a tensor descriptor";
+
+  let description = [{
+    This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects.
+    Ideally, we can remove this once the APIs are fully fleshed out.
+  }];
+
+  let arguments = (ins TT_Ptr:$rawDesc);
+  let results = (outs TT_TensorDescType:$result);
+
+  let assemblyFormat = [{
+    $rawDesc attr-dict `:` qualified(type($rawDesc)) `to` qualified(type($result))
+  }];
 }

 // The following ops, including `call`, `func`, and `return` are copied and modified from
@@ -1195,20 +1211,19 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable
 }


-def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [
-    MemoryEffects<[MemRead<GlobalMemory>]>]> {
+def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [MemoryEffects<[MemRead<GlobalMemory>]>]> {
   let summary = "Load from descriptor";
   let description = [{
     This operation will be lowered to Nvidia TMA load operation on targets supporting it.
-    `desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
+    `desc` is a tensor descriptor object.
     The destination tensor type and shape must match the descriptor otherwise the result is undefined.

     This is an escape hatch and is only there for testing/experimenting.
     This op will be removed in the future.
   }];
   let arguments = (
     ins
-    TT_PtrType:$desc_ptr,
+    TT_TensorDescType:$desc,
     Variadic<I32>:$indices,
     DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
     DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict
@@ -1217,36 +1232,37 @@ def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [
   let results = (outs TT_Tensor:$result);

   let assemblyFormat = [{
-    $desc_ptr `[` $indices `]`
+    $desc `[` $indices `]`
     oilist(
       `cacheModifier` `=` $cache |
       `evictionPolicy` `=` $evict
     )
-    attr-dict `:` qualified(type($desc_ptr)) `->` type($result)
+    attr-dict `:` qualified(type($desc)) `->` type($result)
   }];
 }

 def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [
-    MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>]> {
+  MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
+]> {
   let summary = "store value based on descriptor";
   let description = [{
     This operation will be lowered to Nvidia TMA store operation on targets supporting it.
-    `desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
+    `desc` is a tensor descriptor object.
     The shape and types of `src` must match the descriptor otherwise the result is undefined.

     This is an escape hatch and is only there for testing/experimenting.
     This op will be removed in the future.
   }];
   let arguments = (
     ins
-    TT_PtrType:$desc_ptr,
+    TT_TensorDescType:$desc,
     TT_Tensor:$src,
     Variadic<I32>:$indices
   );

   let assemblyFormat = [{
-    $desc_ptr `[` $indices `]` `,` $src
-    attr-dict `:` qualified(type($desc_ptr)) `,` type($src)
+    $desc `[` $indices `]` `,` $src
+    attr-dict `:` qualified(type($desc)) `,` type($src)
   }];
 }
```
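
Taken together, these definitions give the descriptor ops typed operands. For a sense of the resulting syntax, here is a hand-written sketch that follows the assembly formats above; the `f16` element type, the 128x64 block shape, and the value names are invented for illustration:

```mlir
%desc = tt.reinterpret_tensor_descriptor %raw : !tt.ptr<i8> to !tt.tensordesc<tensor<128x64xf16>>
%tile = tt.experimental_descriptor_load %desc[%x, %y]
    : !tt.tensordesc<tensor<128x64xf16>> -> tensor<128x64xf16>
tt.experimental_descriptor_store %desc[%x, %y], %tile
    : !tt.tensordesc<tensor<128x64xf16>>, tensor<128x64xf16>
```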

include/triton/Dialect/Triton/IR/TritonTypes.td

Lines changed: 11 additions & 0 deletions
```diff
@@ -140,5 +140,16 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]>
   let hasCustomAssemblyFormat = 1;
 }

+// Result type of ExperimentalMakeTensorDescriptor
+def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
+  let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system";
+
+  let description = [{
+    A portable abstraction for nvidia-TMA descriptors.
+  }];
+
+  let parameters = (ins "RankedTensorType":$blockType);
+  let assemblyFormat = "`<` $blockType `>`";
+}

 #endif
```
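
Given the `parameters` and `assemblyFormat` declared above, the type prints with its block type inline. A minimal sketch (the alias name and the 64x64 `f32` block type are arbitrary examples):

```mlir
// Alias purely for illustration; any ranked tensor type may serve as the block type.
!example_desc = !tt.tensordesc<tensor<64x64xf32>>
```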

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
```diff
@@ -295,7 +295,7 @@ def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc", [MemoryEffects<[Me
       $_builder.getI32IntegerAttr(nbytes), $_builder.getI32IntegerAttr(alignment)); }]>
   ];

-  let assemblyFormat = [{attr-dict `:` type($result)}];
+  let assemblyFormat = [{attr-dict `:` qualified(type($result))}];
 }

 #endif
```

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 20 additions & 0 deletions
```diff
@@ -185,6 +185,26 @@ def TTNG_WaitBarrierOp : TTNG_Op<"wait_barrier", [DeclareOpInterfaceMethods<Memo
   let assemblyFormat = "$alloc `,` $phase attr-dict `:` type($alloc)";
 }

+def TTNG_TensorDescToTMAPtrOp : TTNG_Op<"tensor_desc_to_tma_ptr", [Pure]> {
+  let summary = "Convert tensor descriptor to pointer to tma descriptor";
+
+  let arguments = (ins TT_TensorDescType:$desc);
+  let results = (outs TT_Ptr:$ptr);
+
+  let assemblyFormat = [{
+    $desc attr-dict `:` qualified(type($desc)) `to` qualified(type($ptr))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$desc), [{
+      auto ptrTy = triton::PointerType::get($_builder.getI8Type(), 1);
+      build($_builder, $_state, ptrTy, desc);
+    }]>
+  ];
+
+  let hasCanonicalizeMethod = 1;
+}
+

 def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "copy data based on descriptor from global memory to local memory asynchronously";
```
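
Rendered per its declared assembly format, the op should look roughly like the sketch below (value names and the block type are invented). Note that the TableGen format spells the separator `to`, whereas the commit description's summary writes `->`:

```mlir
%ptr = triton_nvidia_gpu.tensor_desc_to_tma_ptr %desc
    : !tt.tensordesc<tensor<128x64xf16>> to !tt.ptr<i8>
```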

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

Lines changed: 4 additions & 0 deletions
```diff
@@ -28,6 +28,10 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
   addConversion([&](MemDescType type) -> std::optional<Type> {
     return convertMemDescType(type, targetInfo);
   });
+  addConversion([](TensorDescType type) -> std::optional<Type> {
+    auto ctx = type.getContext();
+    return LLVM::LLVMPointerType::get(ctx, 1);
+  });
   addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional<Type> {
     return convertAsyncToken(type);
   });
```
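
In effect, the converter erases the block type and keeps only the underlying address. A sketch of the correspondence using type aliases (the 32x32 `f32` block is an arbitrary example):

```mlir
!desc_before_conversion = !tt.tensordesc<tensor<32x32xf32>>
!desc_after_conversion  = !llvm.ptr<1>  // address space 1 is global memory
```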

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 11 additions & 5 deletions
```diff
@@ -8,6 +8,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "llvm/Support/ErrorHandling.h"

 namespace mlir {
 namespace triton {
@@ -863,12 +864,17 @@ OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) {
 //-- MakeTensorDescOp --
 void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
                              Value base, ValueRange shape, ValueRange strides,
-                             ArrayRef<int32_t> tensorShape) {
-  auto resultTy = getPointerType(builder.getI8Type());
-  assert(resultTy.getContext());
+                             ArrayRef<int32_t> blockShape) {
+  auto ptrTy = dyn_cast<triton::PointerType>(base.getType());
+  if (!ptrTy) {
+    llvm::report_fatal_error("Expected pointer type");
+  }
+  auto elemTy = ptrTy.getPointeeType();

-  return build(builder, state, resultTy, base, shape, strides,
-               builder.getDenseI32ArrayAttr(tensorShape));
+  SmallVector<int64_t> blockShape64(blockShape);
+  auto blockTy = RankedTensorType::get(blockShape64, elemTy);
+  auto descTy = TensorDescType::get(builder.getContext(), blockTy);
+  return build(builder, state, descTy, base, shape, strides);
 }

 // The following ops, including `call`, `func`, and `return` are copied and
```
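
With this builder, the descriptor's result type is derived from the base pointer's pointee type plus the requested block shape. A sketch of the IR it should produce (value names and the 2-D `f16` shapes are invented):

```mlir
%desc = tt.make_tensor_descriptor %base, [%dim0, %dim1], [%stride0, %stride1]
    : !tt.ptr<f16>, !tt.tensordesc<tensor<128x64xf16>>
```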

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 6 additions & 4 deletions
```diff
@@ -34,13 +34,13 @@ namespace tt = mlir::triton;
 namespace ttg = mlir::triton::gpu;
 namespace ttng = mlir::triton::nvidia_gpu;

-// TODO: We can extra some helpers into common utilities once we add more
+// TODO: We can extract some helpers into common utilities once we add more
 // schedules.

 namespace {

 struct LoadInfo {
-  // Layout of the data in the shared memory.
+  // Layout of the data in shared memory.
   ttg::SharedEncodingAttr sharedEncoding = nullptr;
   // Blocked encoding is used for loads not used by the dot.
   ttg::BlockedEncodingAttr blockedEncoding = nullptr;
@@ -239,9 +239,11 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp,

   Value pred = builder.createWithStage<arith::ConstantIntOp>(loc, stage,
                                                              clusterId, 1, 1);
+  Value tmaPtr =
+      builder.createWithStage<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
+          loc, stage, clusterId, loadOp.getDesc());
   Operation *copy = builder.createWithStage<ttng::AsyncTMACopyGlobalToLocalOp>(
-      loc, stage, clusterId, loadOp.getDescPtr(), loadOp.getIndices(), barrier,
-      view, pred);
+      loc, stage, clusterId, tmaPtr, loadOp.getIndices(), barrier, view, pred);

   bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3;
```

lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp

Lines changed: 3 additions & 1 deletion
```diff
@@ -63,8 +63,10 @@ static void createTMAAsyncCopy(scf::ForOp &forOp,
   builder.create<ttng::TMAStoreWait>(loc, 0);
   builder.create<ttg::LocalStoreOp>(loc, storeOp.getSrc(), alloc);
   builder.create<ttng::FenceAsyncSharedOp>(loc, false);
+  Value tmaPtr = builder.create<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
+      loc, storeOp.getDesc());
   builder.create<ttng::AsyncTMACopyLocalToGlobalOp>(
-      loc, storeOp.getDescPtr(), storeOp.getIndices(), alloc);
+      loc, tmaPtr, storeOp.getIndices(), alloc);

   storeOp->erase();
 }
```

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 12 additions & 0 deletions
```diff
@@ -160,6 +160,18 @@ void WaitBarrierOp::getEffects(
                        mlir::triton::gpu::SharedMemory::get());
 }

+// -- TensorDescToTMAPtrOp --
+LogicalResult TensorDescToTMAPtrOp::canonicalize(TensorDescToTMAPtrOp op,
+                                                 PatternRewriter &rewriter) {
+  // tensor_desc_to_tma_ptr(reinterpret_tensor_desc(ptr)) -> ptr
+  if (auto reinterpret =
+          op.getDesc().getDefiningOp<triton::ReinterpretTensorDescOp>()) {
+    rewriter.replaceOp(op, reinterpret.getRawDesc());
+    return success();
+  }
+  return failure();
+}
+
 // -- AsyncTMACopyGlobalToLocalOp --
 LogicalResult AsyncTMACopyGlobalToLocalOp::verify() {
   if (failed(verifyBarrierType(*this, getBarrier().getType())))
```
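
The canonicalization simply cancels a pointer-to-descriptor round trip. Sketched at the IR level (names and types invented):

```mlir
// Before canonicalization:
%desc = tt.reinterpret_tensor_descriptor %raw : !tt.ptr<i8> to !tt.tensordesc<tensor<64x64xf16>>
%ptr  = triton_nvidia_gpu.tensor_desc_to_tma_ptr %desc : !tt.tensordesc<tensor<64x64xf16>> to !tt.ptr<i8>
// After: uses of %ptr are rewritten to use %raw, and both casts become dead.
```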

lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp

Lines changed: 9 additions & 3 deletions
```diff
@@ -60,8 +60,10 @@ class TMALoadLowering : public OpRewritePattern<ExperimentalDescriptorLoadOp> {
     Value pred = rewriter.create<arith::ConstantIntOp>(loc, 1, 1);
     rewriter.create<triton::nvidia_gpu::BarrierExpectOp>(loc, barrierAlloc,
                                                          sizeInBytes, pred);
+    Value tmaPtr = rewriter.create<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
+        loc, op.getDesc());
     rewriter.create<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(
-        loc, op.getDescPtr(), op.getIndices(), barrierAlloc, alloc, pred);
+        loc, tmaPtr, op.getIndices(), barrierAlloc, alloc, pred);
     Value phase = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
     rewriter.create<WaitBarrierOp>(loc, barrierAlloc, phase);
     rewriter.create<InvalBarrierOp>(loc, barrierAlloc);
@@ -95,8 +97,10 @@ class TMAStoreLowering
         encoding, sharedMemorySpace, /*mutableMemory=*/true);
     Value alloc = rewriter.create<LocalAllocOp>(loc, memDescType, op.getSrc());
     rewriter.create<triton::nvidia_gpu::FenceAsyncSharedOp>(loc, false);
+    Value tmaPtr = rewriter.create<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
+        loc, op.getDesc());
     rewriter.create<triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp>(
-        loc, op.getDescPtr(), op.getIndices(), alloc);
+        loc, tmaPtr, op.getIndices(), alloc);
     rewriter.create<triton::nvidia_gpu::TMAStoreWait>(loc, 0);
     rewriter.eraseOp(op);
     return success();
@@ -194,7 +198,9 @@ class TMACreateDescLowering : public OpRewritePattern<MakeTensorDescOp> {
         /*fill_mode=*/rewriter.getI32IntegerAttr(0));
     rewriter.create<triton::ExperimentalTensormapFenceproxyAcquireOp>(
         loc, alloc.getResult());
-    rewriter.replaceOp(op, alloc);
+    auto newDesc = rewriter.create<triton::ReinterpretTensorDescOp>(
+        loc, op.getType(), alloc.getResult());
+    rewriter.replaceOp(op, newDesc);
     return success();
   }
 };
```
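
After `TMACreateDescLowering`, then, creating a descriptor bottoms out in a scratch allocation that is immediately reinterpreted back into the typed descriptor, so downstream consumers stay typed. A schematic sketch, not actual pass output; the scratch-alloc attributes, shapes, and the elided tensormap setup are illustrative:

```mlir
%scratch = triton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr<i8>
// ... experimental_tensormap_create + fenceproxy_acquire initialize %scratch ...
%desc = tt.reinterpret_tensor_descriptor %scratch
    : !tt.ptr<i8> to !tt.tensordesc<tensor<128x64xf16>>
```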
