Commit 7d4f1ce

Merge commit '1cf06c5e1982eba8f17062e1c6c3d3fa458597b2'
2 parents c17a0fb + 1cf06c5 commit 7d4f1ce

44 files changed, +1057 -701 lines

README.md

Lines changed: 3 additions & 0 deletions
@@ -176,6 +176,9 @@ For detailed instructions on how to debug Triton's frontend, please refer to this
   kernels. Use `MLIR_ENABLE_DUMP=kernelName` to dump for a specific kernel only.
   - Triton cache can interfere with the dump. In cases where `MLIR_ENABLE_DUMP=1` does not work, try cleaning your triton cache: `rm -r ~/.triton/cache/*`
 - `LLVM_IR_ENABLE_DUMP=1` dumps the IR before every pass run over the LLVM IR.
+- `TRITON_REPRODUCER_PATH=<reproducer_path>` will generate an MLIR reproducer file
+  at `<reproducer_path>` before each MLIR compiler stage. If any of the stages fail,
+  `<reproducer_path>` will be a local MLIR reproducer captured right before the failing pass.
 - `TRITON_INTERPRET=1` uses the Triton interpreter instead of running on the
   GPU. You can insert Python breakpoints in your kernel code!
 - `TRITON_ENABLE_LLVM_DEBUG=1` passes `-debug` to LLVM, printing a lot of

cmake/AddTritonUnitTest.cmake

Lines changed: 1 addition & 1 deletion
@@ -35,5 +35,5 @@ function(add_triton_ut)
   # Without the TEST_DISCOVERY_TIMEOUT, the tests randomly time out on my mac
   # laptop. I think the issue may be that the very first time you run a program
   # it's a bit slow.
-  gtest_discover_tests(${__NAME} PROPERTIES TEST_DISCOVERY_TIMEOUT 60)
+  gtest_discover_tests(${__NAME} DISCOVERY_TIMEOUT 60)
 endfunction()

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 40 additions & 24 deletions
@@ -956,9 +956,10 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
 //
 // Make Tensor Descriptor Op
 //
-def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor",
-    [Pure,
-     SameVariadicOperandSize]> {
+def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
+  Pure,
+  SameVariadicOperandSize,
+]> {
   let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size";

   let description = [{
@@ -969,23 +970,38 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor",
   let arguments = (ins
     TT_Ptr:$base,
     Variadic<I32>:$shape,
-    Variadic<I64>:$strides,
-    DenseI32ArrayAttr:$tensorShape
+    Variadic<I64>:$strides
   );

-  // TODO(peterbell10): define a custom IR type to represent descriptors
-  let results = (outs TT_Ptr:$result);
+  let results = (outs TT_TensorDescType:$result);

   let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";

   let builders = [
-    OpBuilder<(ins
-      "Value":$base,
-      "ValueRange":$shape,
-      "ValueRange":$strides,
-      "ArrayRef<int32_t>":$tensorShape
-    )>
+    OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape)>
   ];
+
+  let extraClassDeclaration = [{
+    ArrayRef<int64_t> getTensorShape() {
+      return getType().getBlockType().getShape();
+    }
+  }];
+}
+
+def ReinterpretTensorDescOp : TT_Op<"reinterpret_tensor_descriptor", [Pure]> {
+  let summary = "Reinterpret a pointer as a tensor descriptor";
+
+  let description = [{
+    This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects.
+    Ideally, we can remove this once the APIs are fully fleshed out.
+  }];
+
+  let arguments = (ins TT_Ptr:$rawDesc);
+  let results = (outs TT_TensorDescType:$result);
+
+  let assemblyFormat = [{
+    $rawDesc attr-dict `:` qualified(type($rawDesc)) `to` qualified(type($result))
+  }];
 }

 // The following ops, including `call`, `func`, and `return` are copied and modified from
@@ -1195,20 +1211,19 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable


-def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [
-    MemoryEffects<[MemRead<GlobalMemory>]>]> {
+def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [MemoryEffects<[MemRead<GlobalMemory>]>]> {
   let summary = "Load from descriptor";
   let description = [{
     This operation will be lowered to Nvidia TMA load operation on targets supporting it.
-    `desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
+    `desc` is a tensor descriptor object.
     The destination tensor type and shape must match the descriptor otherwise the result is undefined.

     This is an escape hatch and is only there for testing/experimenting.
     This op will be removed in the future.
   }];
   let arguments = (
     ins
-    TT_PtrType:$desc_ptr,
+    TT_TensorDescType:$desc,
     Variadic<I32>:$indices,
     DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
     DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict
@@ -1217,36 +1232,37 @@ def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [
   let results = (outs TT_Tensor:$result);

   let assemblyFormat = [{
-    $desc_ptr `[` $indices `]`
+    $desc `[` $indices `]`
     oilist(
       `cacheModifier` `=` $cache |
       `evictionPolicy` `=` $evict
     )
-    attr-dict `:` qualified(type($desc_ptr)) `->` type($result)
+    attr-dict `:` qualified(type($desc)) `->` type($result)
   }];
 }

 def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [
-    MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>]> {
+    MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
+]> {
   let summary = "store value based on descriptor";
   let description = [{
     This operation will be lowered to Nvidia TMA store operation on targets supporting it.
-    `desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
+    `desc` is a tensor descriptor object.
     The shape and types of `src` must match the descriptor otherwise the result is undefined.

     This is an escape hatch and is only there for testing/experimenting.
     This op will be removed in the future.
   }];
   let arguments = (
     ins
-    TT_PtrType:$desc_ptr,
+    TT_TensorDescType:$desc,
     TT_Tensor:$src,
     Variadic<I32>:$indices
   );

   let assemblyFormat = [{
-    $desc_ptr `[` $indices `]` `,` $src
-    attr-dict `:` qualified(type($desc_ptr)) `,` type($src)
+    $desc `[` $indices `]` `,` $src
+    attr-dict `:` qualified(type($desc)) `,` type($src)
   }];
 }
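For orientation, a hand-written sketch of the textual IR these updated definitions accept; the value names, shapes, and element types are illustrative, not taken from this commit:

// Create a typed descriptor; the result is now !tt.tensordesc rather than !tt.ptr<i8>.
%desc = tt.make_tensor_descriptor %base, [%dim0, %dim1], [%s0, %s1] : <f16>, <tensor<128x64xf16>>

// Transitional escape hatch: view a raw TMA pointer as a typed descriptor.
%desc2 = tt.reinterpret_tensor_descriptor %raw : !tt.ptr<i8> to !tt.tensordesc<tensor<128x64xf16>>

// The experimental load/store ops now take the descriptor value directly.
%tile = tt.experimental_descriptor_load %desc[%x, %y] : !tt.tensordesc<tensor<128x64xf16>> -> tensor<128x64xf16>
tt.experimental_descriptor_store %desc[%x, %y], %tile : !tt.tensordesc<tensor<128x64xf16>>, tensor<128x64xf16>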

include/triton/Dialect/Triton/IR/TritonTypes.td

Lines changed: 11 additions & 0 deletions
@@ -140,5 +140,16 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]>
   let hasCustomAssemblyFormat = 1;
 }

+// Result type of ExperimentalMakeTensorDescriptor
+def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
+  let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system";
+
+  let description = [{
+    A portable abstraction for nvidia-TMA descriptors.
+  }];
+
+  let parameters = (ins "RankedTensorType":$blockType);
+  let assemblyFormat = "`<` $blockType `>`";
+}

 #endif
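As a quick illustration (block shape hypothetical), the new type prints with its block type as the sole parameter:

!tt.tensordesc<tensor<64x64xf32>>  // descriptor for 64x64 blocks of f32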

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -295,7 +295,7 @@ def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc", [MemoryEffects<[Me
       $_builder.getI32IntegerAttr(nbytes), $_builder.getI32IntegerAttr(alignment)); }]>
   ];

-  let assemblyFormat = [{attr-dict `:` type($result)}];
+  let assemblyFormat = [{attr-dict `:` qualified(type($result))}];
 }

 #endif

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 20 additions & 0 deletions
@@ -185,6 +185,26 @@ def TTNG_WaitBarrierOp : TTNG_Op<"wait_barrier", [DeclareOpInterfaceMethods<Memo
   let assemblyFormat = "$alloc `,` $phase attr-dict `:` type($alloc)";
 }

+def TTNG_TensorDescToTMAPtrOp : TTNG_Op<"tensor_desc_to_tma_ptr", [Pure]> {
+  let summary = "Convert tensor descriptor to pointer to tma descriptor";
+
+  let arguments = (ins TT_TensorDescType:$desc);
+  let results = (outs TT_Ptr:$ptr);
+
+  let assemblyFormat = [{
+    $desc attr-dict `:` qualified(type($desc)) `to` qualified(type($ptr))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$desc), [{
+      auto ptrTy = triton::PointerType::get($_builder.getI8Type(), 1);
+      build($_builder, $_state, ptrTy, desc);
+    }]>
+  ];
+
+  let hasCanonicalizeMethod = 1;
+}
+

 def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "copy data based on descriptor from global memory to local memory asynchronously";
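A sketch of the new op in textual IR (block shape illustrative). Per the builder above, the default result is an i8 pointer in address space 1, i.e. a raw pointer to the TMA descriptor in global memory:

%tma_ptr = ttng.tensor_desc_to_tma_ptr %desc : !tt.tensordesc<tensor<128x64xf16>> to !tt.ptr<i8>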

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

Lines changed: 4 additions & 0 deletions
@@ -28,6 +28,10 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
   addConversion([&](MemDescType type) -> std::optional<Type> {
     return convertMemDescType(type, targetInfo);
   });
+  addConversion([](TensorDescType type) -> std::optional<Type> {
+    auto ctx = type.getContext();
+    return LLVM::LLVMPointerType::get(ctx, 1);
+  });
   addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional<Type> {
     return convertAsyncToken(type);
   });
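The net effect: a !tt.tensordesc value is carried through LLVM lowering as an opaque pointer in address space 1 (global memory). Sketched mapping, with an illustrative block type:

// before conversion: %desc : !tt.tensordesc<tensor<128x64xf16>>
// after conversion:  %desc : !llvm.ptr<1>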

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 4 additions & 9 deletions
@@ -59,15 +59,10 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
     Type retType = getTypeConverter()->convertType(op.getType());
     auto retShapedType = cast<ShapedType>(retType);
     auto value = dyn_cast<DenseElementsAttr>(adaptor.getValue());
-    if (dyn_cast<RankedTensorType>(retShapedType)) {
-      assert(value);
-      if (value.getElementType().isInteger(1) && value.isSplat())
-        // Workaround until https://reviews.llvm.org/D133743 is included.
-        value =
-            DenseElementsAttr::get(retShapedType, value.getSplatValue<bool>());
-      else
-        // This is a hack. We just want to add encoding
-        value = value.reshape(retShapedType);
+    if (isa<RankedTensorType>(retShapedType)) {
+      assert(value && "expected a dense elements attribute");
+      // This is a hack. We just want to add encoding.
+      value = value.reshape(retShapedType);
     }
     addNamedAttrs(rewriter.replaceOpWithNewOp<arith::ConstantOp>(
                       op, retShapedType, value),

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 11 additions & 5 deletions
@@ -8,6 +8,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "llvm/Support/ErrorHandling.h"

 namespace mlir {
 namespace triton {
@@ -863,12 +864,17 @@ OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) {
 //-- MakeTensorDescOp --
 void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
                              Value base, ValueRange shape, ValueRange strides,
-                             ArrayRef<int32_t> tensorShape) {
-  auto resultTy = getPointerType(builder.getI8Type());
-  assert(resultTy.getContext());
+                             ArrayRef<int32_t> blockShape) {
+  auto ptrTy = dyn_cast<triton::PointerType>(base.getType());
+  if (!ptrTy) {
+    llvm::report_fatal_error("Expected pointer type");
+  }
+  auto elemTy = ptrTy.getPointeeType();

-  return build(builder, state, resultTy, base, shape, strides,
-               builder.getDenseI32ArrayAttr(tensorShape));
+  SmallVector<int64_t> blockShape64(blockShape);
+  auto blockTy = RankedTensorType::get(blockShape64, elemTy);
+  auto descTy = TensorDescType::get(builder.getContext(), blockTy);
+  return build(builder, state, descTy, base, shape, strides);
 }

 // The following ops, including `call`, `func`, and `return` are copied and

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 6 additions & 4 deletions
@@ -34,13 +34,13 @@ namespace tt = mlir::triton;
 namespace ttg = mlir::triton::gpu;
 namespace ttng = mlir::triton::nvidia_gpu;

-// TODO: We can extra some helpers into common utilities once we add more
+// TODO: We can extract some helpers into common utilities once we add more
 // schedules.

 namespace {

 struct LoadInfo {
-  // Layout of the data in the shared memory.
+  // Layout of the data in shared memory.
   ttg::SharedEncodingAttr sharedEncoding = nullptr;
   // Blocked encoding is used for loads not used by the dot.
   ttg::BlockedEncodingAttr blockedEncoding = nullptr;
@@ -239,9 +239,11 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp,

   Value pred = builder.createWithStage<arith::ConstantIntOp>(loc, stage,
                                                              clusterId, 1, 1);
+  Value tmaPtr =
+      builder.createWithStage<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
+          loc, stage, clusterId, loadOp.getDesc());
   Operation *copy = builder.createWithStage<ttng::AsyncTMACopyGlobalToLocalOp>(
-      loc, stage, clusterId, loadOp.getDescPtr(), loadOp.getIndices(), barrier,
-      view, pred);
+      loc, stage, clusterId, tmaPtr, loadOp.getIndices(), barrier, view, pred);

   bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3;