Skip to content

Commit b6485a7

Browse files
Merge commit 'fe8ee0c2dff9036de6b27e91fc4d02eb0fdbf925'
2 parents 328fd8a + fe8ee0c commit b6485a7

File tree

27 files changed

+809
-82
lines changed

27 files changed

+809
-82
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -252,8 +252,6 @@ def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
252252
representing a transposed view of the buffer.
253253
}];
254254

255-
let arguments = (ins TTG_MemDescType:$src, Variadic<I32>:$order);
256-
257255
let arguments = (
258256
ins TTG_MemDescType:$src,
259257
DenseI32ArrayAttr:$order
@@ -284,6 +282,26 @@ def TTG_MemDescReshapeOp : TTG_Op<"memdesc_reshape", [Pure,
284282
let hasVerifier = 1;
285283
}
286284

285+
// Op definition for `ttg.memdesc_reinterpret`.
// Declared with the `Pure` and `MemDescViewTrait` traits; takes a single
// memory-descriptor operand and produces a single memory-descriptor result.
def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewTrait]> {
286+
let summary = "reinterpret a memory descriptor as a different type and shape";
287+
288+
let description = [{
289+
The `ttg.memdesc_reinterpret` operation reinterprets a memory descriptor
290+
as one with a different shape and element type. Because memory descriptors
291+
lack strides, this operation is only valid if the original memory descriptor
292+
is contiguous.
293+
}];
294+
295+
// One operand ($src) and one result ($result), both of TTG_MemDescType.
let arguments = (ins TTG_MemDescType:$src);
296+
let results = (outs TTG_MemDescType:$result);
297+
298+
// Textual form: `<src> <attr-dict> : <src type> -> <result type>`.
let assemblyFormat = [{
299+
$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))
300+
}];
301+
302+
// Requests a hand-written C++ `verify()` hook for this op.
let hasVerifier = 1;
303+
}
304+
287305
def TTG_LocalLoadOp : TTG_Op<"local_load"> {
288306
let summary = "Load a buffer from local memory into a distributed tensor";
289307

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -369,6 +369,7 @@ def TritonGPUCanonicalize: Pass<"tritongpu-canonicalize"> {
369369
}];
370370
let dependentDialects = [
371371
"mlir::arith::ArithDialect",
372+
"mlir::cf::ControlFlowDialect",
372373
"mlir::scf::SCFDialect",
373374
];
374375
}

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -480,6 +480,27 @@ struct MemDescSubviewOpConversion
480480
return success();
481481
}
482482
};
483+
484+
// Conversion pattern that lowers MemDescReinterpretOp.
// It unpacks the source shared-memory object, builds a new object that reuses
// the source's base value with the destination element type and rank, and
// replaces the op with the struct packed from that new object.
struct MemDescReinterpretOpConversion
485+
: public ConvertOpToLLVMPattern<MemDescReinterpretOp> {
486+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
487+
488+
// Rewrites `op` in place; always returns success().
LogicalResult matchAndRewrite(MemDescReinterpretOp op, OpAdaptor adaptor,
489+
ConversionPatternRewriter &b) const override {
490+
Location loc = op.getLoc();
491+
MemDescType srcTy = op.getSrc().getType();
492+
MemDescType dstTy = op.getType();
493+
// Element types of source and result, as mapped by this pattern's type
// converter.
Type srcElemTy = getTypeConverter()->convertType(srcTy.getElementType());
494+
Type dstElemTy = getTypeConverter()->convertType(dstTy.getElementType());
495+
496+
// Unpack the already-converted source operand using the source element type.
auto smemObj =
497+
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), srcElemTy, b);
498+
// Only the base of the source object is carried over; element type and rank
// come from the result type.
// NOTE(review): presumably this constructor initialises the remaining
// fields (e.g. offsets) to defaults -- confirm against SharedMemoryObject.
SharedMemoryObject newObj(smemObj.getBase(), dstElemTy, dstTy.getRank(),
499+
loc, b);
500+
b.replaceOp(op, getStructFromSharedMemoryObject(loc, newObj, b));
501+
return success();
502+
}
503+
};
483504
} // namespace
484505

485506
void mlir::triton::populateViewOpToLLVMPatterns(
@@ -497,4 +518,5 @@ void mlir::triton::populateViewOpToLLVMPatterns(
497518
patterns.add<TransOpConversion>(typeConverter, benefit);
498519
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
499520
patterns.add<MemDescSubviewOpConversion>(typeConverter, benefit);
521+
patterns.add<MemDescReinterpretOpConversion>(typeConverter, benefit);
500522
}

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 24 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -439,14 +439,20 @@ MemDescTransOp::inferReturnTypes(MLIRContext *context,
439439
return failure();
440440
}
441441
}
442+
443+
// Permute the last `rank` dims of the source alloc shape.
444+
SmallVector<int64_t> allocShape =
445+
applyPermutation(argTy.getAllocShape().take_back(order.size()), order);
446+
allocShape.insert(allocShape.begin(), argTy.getAllocShape().begin(),
447+
argTy.getAllocShape().end() - order.size());
448+
442449
inferredReturnTypes.push_back(
443450
MemDescType::get(retShape, retEltTy, retEncoding, argTy.getMemorySpace(),
444-
argTy.getMutableMemory()));
451+
argTy.getMutableMemory(), allocShape));
445452
return success();
446453
}
447454

448455
// MemDescReshapeOp
449-
450456
LogicalResult MemDescReshapeOp::verify() {
451457
MemDescType dstType = getResult().getType();
452458
MemDescType srcType = getSrc().getType();
@@ -472,6 +478,13 @@ LogicalResult MemDescReshapeOp::verify() {
472478
return success();
473479
}
474480

481+
// MemDescReinterpretOp
482+
// Verifier for MemDescReinterpretOp: fails with an error unless the source
// and result memory-descriptor types share the same memory space.
// NOTE(review): the op description in TritonGPUOps.td states the source must
// be contiguous; no such check is performed here -- confirm whether that is
// enforced elsewhere or intentionally left to the caller.
LogicalResult MemDescReinterpretOp::verify() {
483+
if (getSrc().getType().getMemorySpace() != getType().getMemorySpace())
484+
return emitError("source and destination memory space must match");
485+
return success();
486+
}
487+
475488
// LocalAllocOp
476489
void LocalAllocOp::getEffects(
477490
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
@@ -623,20 +636,15 @@ LogicalResult MemDescSubviewOp::verify() {
623636
"only nD -> (n-1)D rank-reducing subviews are supported");
624637
}
625638
for (auto offset : getOffsets().take_back(dstTy.getRank())) {
626-
if (auto constOp = offset.getDefiningOp<arith::ConstantOp>()) {
627-
if (auto offsetInt = dyn_cast<IntegerAttr>(constOp.getValue())) {
628-
if (offsetInt.getInt() != 0) {
629-
return emitError("only first offset can be non-zero for a "
630-
"rank-reducing subview");
631-
}
632-
} else {
633-
return emitError(
634-
"only integer constant values are allowed for the split");
635-
}
636-
} else {
639+
APInt value;
640+
if (!matchPattern(offset, m_ConstantInt(&value))) {
637641
return emitError("only constant values are allowed outside the front "
638642
"dimension in a rank-reducing subview");
639643
}
644+
if (!value.isZero()) {
645+
return emitError(
646+
"only first offset can be non-zero for a rank-reducing subview");
647+
}
640648
}
641649
return success();
642650
}
@@ -658,16 +666,10 @@ LogicalResult MemDescSubviewOp::verify() {
658666
}
659667
SmallVector<int64_t> offsets;
660668
for (auto offset : getOffsets()) {
661-
if (auto constOp = offset.getDefiningOp<arith::ConstantOp>()) {
662-
if (auto offsetInt = dyn_cast<IntegerAttr>(constOp.getValue())) {
663-
offsets.push_back(offsetInt.getInt());
664-
} else {
665-
return emitError(
666-
"only integer constant values are allowed for the split");
667-
}
668-
} else {
669+
APInt value;
670+
if (!matchPattern(offset, m_ConstantInt(&value)))
669671
return emitError("only constant values are allowed for the split");
670-
}
672+
offsets.push_back(value.getSExtValue());
671673
}
672674
// Identity subview
673675
if (dim == -1) {

lib/Dialect/TritonGPU/Transforms/Canonicalize.cpp

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#include "mlir/Dialect/Arith/IR/Arith.h"
2+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
23
#include "mlir/Dialect/SCF/IR/SCF.h"
34
#include "mlir/Pass/Pass.h"
45
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -32,6 +33,8 @@ void Canonicalize::runOnOperation() {
3233
patterns);
3334
ctx->getLoadedDialect<scf::SCFDialect>()->getCanonicalizationPatterns(
3435
patterns);
36+
ctx->getLoadedDialect<cf::ControlFlowDialect>()->getCanonicalizationPatterns(
37+
patterns);
3538
populateForOpDeadArgumentElimination(patterns);
3639

3740
// Populate select Triton canonicalization patterns. The important patterns to
@@ -43,4 +46,6 @@ void Canonicalize::runOnOperation() {
4346
ExpandDimsOp::getCanonicalizationPatterns(patterns, ctx);
4447
ttg::WarpSpecializeOp::getCanonicalizationPatterns(patterns, ctx);
4548
ttng::TensorDescToTMAPtrOp::getCanonicalizationPatterns(patterns, ctx);
49+
50+
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
4651
}

lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
#include "mlir/IR/DialectImplementation.h"
3131
#include "mlir/IR/OpImplementation.h"
3232
#include "triton/Analysis/Utility.h"
33+
#include "triton/Dialect/Triton/IR/Interfaces.h"
3334
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
3435
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
3536
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
@@ -264,6 +265,7 @@ void TritonNvidiaGPUDialect::initialize() {
264265
#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc"
265266
>();
266267
addInterfaces<TritonGPUOpAsmInterface>();
268+
addInterfaces<TritonInlinerInterface>();
267269
}
268270

269271
// verify TritonNvidiaGPU ops

lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@ namespace mlir {
1313
namespace triton {
1414
namespace nvidia_gpu {
1515

16+
namespace ttg = triton::gpu;
17+
1618
#define GEN_PASS_DEF_TRITONTENSORMEMORYALLOCATIONPASS
1719
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
1820

@@ -118,7 +120,7 @@ static Interval<int> getLiveIntervals(Value value, Liveness &liveness,
118120
SmallVector<Operation *> users(value.getUsers());
119121
while (!users.empty()) {
120122
Operation *user = users.pop_back_val();
121-
if (!isa<triton::gpu::MemDescSubviewOp>(user))
123+
if (!isa<ttg::MemDescSubviewOp, ttg::MemDescReinterpretOp>(user))
122124
continue;
123125
auto usersLivness = liveness.resolveLiveness(user->getResult(0));
124126
liveOperations.insert(liveOperations.end(), usersLivness.begin(),
@@ -177,10 +179,14 @@ static Operation *getAlloc(Value value) {
177179
while (true) {
178180
if (auto allocOp = value.getDefiningOp<TMEMAllocOp>())
179181
return allocOp;
180-
if (auto subviewOp = value.getDefiningOp<triton::gpu::MemDescSubviewOp>()) {
182+
if (auto subviewOp = value.getDefiningOp<ttg::MemDescSubviewOp>()) {
181183
value = subviewOp.getSrc();
182184
continue;
183185
}
186+
if (auto reinterpOp = value.getDefiningOp<ttg::MemDescReinterpretOp>()) {
187+
value = reinterpOp.getSrc();
188+
continue;
189+
}
184190
auto arg = dyn_cast<BlockArgument>(value);
185191
if (!arg || !isa<triton::gpu::WarpSpecializePartitionsOp>(
186192
arg.getOwner()->getParentOp()))

0 commit comments

Comments
 (0)