
Commit 4969757

Update StructuredToMemref pass
1 parent a7ffd7d commit 4969757

39 files changed

+685 -551 lines changed

include/triton-shared/Conversion/StructuredToMemref/Passes.td

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,12 @@ include "mlir/Pass/PassBase.td"
 
 def StructuredToMemref : Pass<"structured-to-memref", "mlir::ModuleOp"> {
   let summary = "Convert triton structured pointer ops to memref";
+  let options = [
+    Option<"convertArgsOnly", "convert-args-only", "bool", /*default*/"false",
+           "Only run the pre-processing pass which inserts tts.get_structured_state ops used in scf.for">,
+    Option<"noConvertArgs", "no-convert-args", "bool", /*default*/"false",
+           "Only run the pre-processing pass which inserts tts.get_structured_state ops used in scf.for">
+  ];
 }
 
 #endif
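
These options go through the standard MLIR pass-option machinery: the TableGen names become `convertArgsOnly` and `noConvertArgs` members of the generated pass base, and (assuming the repository's usual `triton-shared-opt` style driver) they would be exercised on the command line roughly as `--structured-to-memref=convert-args-only=true` or `--structured-to-memref=no-convert-args=true`.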

lib/Conversion/StructuredToMemref/StructuredToMemref.cpp

Lines changed: 56 additions & 123 deletions
@@ -5,9 +5,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "triton/Dialect/Triton/IR/Types.h"
+
+#include "triton-shared/Analysis/OpFoldResultUtils.h"
 #include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
+
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -18,22 +21,17 @@
 #include "mlir/IR/Types.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "triton-shared/Analysis/OpFoldResultUtils.h"
-#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
-
-#include "triton/Dialect/Triton/IR/Dialect.h"
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR//MemRef.h"
-#include "triton/Dialect/Triton/IR/Types.h"
-
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Debug.h"
 
 #include <algorithm>
 #include <cassert>
@@ -61,6 +59,13 @@ static memref::SubViewOp getSubview(int rank, ArrayRef<OpFoldResult> dims,
                                     offsets, dims, strides);
 }
 
+static Value getPtr(Value v) {
+  while (auto op = v.getDefiningOp()) {
+    v = op->getOperand(0);
+  }
+  return v;
+}
+
 namespace {
 
 struct MakeTensorPtrConverter
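
The new `getPtr` helper (not exercised in the hunks shown here) walks a value's defining ops through their first operand until it reaches a value with no defining op, presumably to recover the underlying base pointer behind a chain of single-result ops such as casts or `tt.addptr`.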
@@ -373,28 +378,27 @@ struct MakeTensorPtrConverter
         op, staticTargetOffset.value_or(ShapedType::kDynamic), staticStrides,
         resultShape);
 
-    // The base ptr, which is from one of the args, would have already been
-    // converted to memref<*> at this point, so get the base from adaptor.
-    //
     // For block pointers, the base could come from a sequence of `tt.addptr`,
-    // which at this point has already been lowered to a sequence of
-    // `memref.reinterpret_cast` ops. The offset in such cases are dynamic.
-    // (see test/Conversion/StructuredToMemref/block_ptr_complex_offset.mlir)
-    //
-    // For non-block pointer cases, the base is the reinterpret_cast of a
-    // function argument. Assert that the offset is a constant 0 in such cases.
-    auto ptr = adaptor.getBase();
-    if (auto reinterpretCast = ptr.getDefiningOp<memref::ReinterpretCastOp>()) {
-      auto offset = reinterpretCast.getMixedOffsets()[0];
-      auto intAttr = getIntAttr(offset);
-      assert(isBlockPtr || (intAttr.has_value() && intAttr.value() == 0));
-      targetOffset = addOFRs(targetOffset, reinterpretCast.getMixedOffsets()[0],
-                             op->getLoc(), rewriter);
+    // which at this point has already been lowered to a single
+    // tts.make_unstructured_tptr via --fold-unstructured-ptr. The offset in
+    // such cases is dynamic and comes from the second operand of
+    // tts.make_unstructured_tptr. See
+    // test/Conversion/StructuredToMemref/block_ptr_complex_offset.mlir
+    // Accumulate the target offset with the offset from
+    // tts.make_unstructured_tptr.
+    if (auto makePtr =
+            op.getBase().getDefiningOp<tts::MakeUnstructuredTensorPtrOp>()) {
+      auto prevOff =
+          rewriter
+              .create<arith::IndexCastOp>(op.getLoc(), rewriter.getIndexType(),
+                                          makePtr.getOffset())
+              .getResult();
+      targetOffset = addOFRs(prevOff, targetOffset, op->getLoc(), rewriter);
     }
 
     auto castOp = rewriter.create<memref::ReinterpretCastOp>(
-        op.getLoc(), resultType, ptr, targetOffset, op.getMixedSizes(),
-        mixedStrides);
+        op.getLoc(), resultType, adaptor.getBase(), targetOffset,
+        op.getMixedSizes(), mixedStrides);
 
     rewriter.replaceOp(op, castOp);
 
@@ -421,6 +425,10 @@ struct MakeTensorPtrConverter
   }
 
 public:
+  MakeTensorPtrConverter(const TypeConverter &typeConverter,
+                         MLIRContext *context)
+      : OpConversionPattern<tts::MakeTensorPtrOp>(typeConverter, context) {}
+
   LogicalResult
   matchAndRewrite(tts::MakeTensorPtrOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -594,8 +602,13 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
     // No mask
     assert(!other && "other value used in non-masked load");
 
-    if (auto unrealizedCast = ptr.getDefiningOp<UnrealizedConversionCastOp>()) {
+    auto ptrDefiningOp = ptr.getDefiningOp();
+    if (ptrDefiningOp->hasAttr(WRAP_SIDE_BY_SIDE) ||
+        ptrDefiningOp->hasAttr(WRAP_STACKED)) {
+
+      auto unrealizedCast = cast<UnrealizedConversionCastOp>(ptrDefiningOp);
       auto memrefs = unrealizedCast.getOperands();
+      assert(memrefs.size() == 2);
       auto block1 = memrefs[0];
       auto block2 = memrefs[1];
 
@@ -664,9 +677,14 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
       });
     }
 
-    if (auto unrealizedCast = ptr.getDefiningOp<UnrealizedConversionCastOp>()) {
+    auto ptrDefiningOp = ptr.getDefiningOp();
+    if (ptrDefiningOp->hasAttr(WRAP_SIDE_BY_SIDE) ||
+        ptrDefiningOp->hasAttr(WRAP_STACKED)) {
+
+      auto unrealizedCast = cast<UnrealizedConversionCastOp>(ptrDefiningOp);
 
       auto memrefs = unrealizedCast.getOperands();
+      assert(memrefs.size() == 2);
       auto block1 = memrefs[0];
       auto block2 = memrefs[1];
 
@@ -700,6 +718,9 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
   }
 
 public:
+  LoadConverter(const TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<tts::LoadOp>(typeConverter, context) {}
+
   LogicalResult
   matchAndRewrite(tts::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -730,6 +751,9 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
   }
 
 public:
+  StoreConverter(const TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<tts::StoreOp>(typeConverter, context) {}
+
   LogicalResult
   matchAndRewrite(tts::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -759,101 +783,10 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
   }
 };
 
-struct ScalarLoadConverter : public OpConversionPattern<triton::LoadOp> {
-  using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (!op.getType().isIntOrIndexOrFloat()) {
-      return failure();
-    }
-
-    auto loc = op->getLoc();
-    auto memrefPtr = adaptor.getPtr();
-    auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext());
-    auto loadOp = rewriter.create<affine::AffineLoadOp>(loc, memrefPtr, zeroMap,
-                                                        std::nullopt);
-    rewriter.replaceOp(op, loadOp.getResult());
-
-    return success();
-  }
-};
-
-struct ScalarStoreConverter : public OpConversionPattern<triton::StoreOp> {
-private:
-  using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
-
-public:
-  LogicalResult
-  matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    if (!op.getValue().getType().isIntOrIndexOrFloat()) {
-      return failure();
-    }
-
-    auto loc = op->getLoc();
-    auto memrefPtr = adaptor.getPtr();
-    auto val = op.getValue();
-    auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext());
-
-    rewriter.create<affine::AffineStoreOp>(loc, val, memrefPtr, zeroMap,
-                                           std::nullopt);
-    rewriter.eraseOp(op);
-
-    return success();
-  }
-};
-
-struct UnrealizedCastConverter
-    : public OpConversionPattern<UnrealizedConversionCastOp> {
-private:
-  using OpConversionPattern<UnrealizedConversionCastOp>::OpConversionPattern;
-
-public:
-  UnrealizedCastConverter(TypeConverter &typeConverter, MLIRContext *context)
-      : OpConversionPattern<UnrealizedConversionCastOp>(typeConverter,
-                                                        context) {}
-
-  LogicalResult
-  matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto resType = op->getResultTypes()[0];
-    auto input = op.getInputs()[0];
-    auto inputType = input.getType();
-
-    if (!isa<triton::PointerType>(resType) ||
-        !isa<MemRefType, UnrankedMemRefType>(inputType)) {
-      return failure();
-    }
-
-    if (auto reinterpretCast =
-            input.getDefiningOp<memref::ReinterpretCastOp>()) {
-      rewriter.replaceOp(op, reinterpretCast);
-    } else {
-      auto ptrType = cast<triton::PointerType>(resType);
-      auto memrefType =
-          cast<MemRefType>(getTypeConverter()->convertType(ptrType));
-
-      auto cast = rewriter.create<memref::ReinterpretCastOp>(
-          op->getLoc(), memrefType, op.getInputs()[0], 0 /*offset*/,
-          SmallVector<int64_t>{1} /*sizes*/,
-          SmallVector<int64_t>{1} /*strides*/);
-
-      rewriter.replaceOp(op, cast);
-    }
-
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::triton::populateStructuredToMemrefConversionPatterns(
     RewritePatternSet &patterns, TypeConverter &typeConverter) {
-  patterns.add<UnrealizedCastConverter>(typeConverter, patterns.getContext());
-  patterns.add<MakeTensorPtrConverter, LoadConverter, StoreConverter,
-               ScalarLoadConverter, ScalarStoreConverter>(
-      patterns.getContext());
+  patterns.add<MakeTensorPtrConverter>(typeConverter, patterns.getContext());
+  patterns.add<LoadConverter, StoreConverter>(patterns.getContext());
 }
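
For context, a minimal sketch (not part of this commit) of how the updated populate function could be driven from a conversion pass; the driver function, legality choices, and the `tts::TritonStructuredDialect` class name are assumptions, while the populate call itself matches the signature shown above:

// Hedged sketch only: names outside this diff are illustrative.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"

using namespace mlir;

static LogicalResult convertStructuredOps(ModuleOp module,
                                          TypeConverter &typeConverter) {
  MLIRContext *ctx = module.getContext();

  RewritePatternSet patterns(ctx);
  // MakeTensorPtrConverter is now constructed with the type converter;
  // LoadConverter and StoreConverter only need the context.
  triton::populateStructuredToMemrefConversionPatterns(patterns, typeConverter);

  ConversionTarget target(*ctx);
  target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
  target.addIllegalDialect<tts::TritonStructuredDialect>();

  return applyPartialConversion(module, target, std::move(patterns));
}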
