diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp index ce5ab690..4585dad4 100644 --- a/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp +++ b/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp @@ -8,6 +8,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h" +#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h" #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" @@ -24,12 +25,7 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" -#include "mlir/Pass/PassManager.h" #include "triton/Dialect/Triton/IR/Types.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Casting.h" - -#include #define DEBUG_TYPE "structured-to-memref" @@ -45,56 +41,19 @@ namespace triton { namespace { -class LoopTypeConverter : public TypeConverter { +class PtrToUnrankedMemrefConverter : public TypeConverter { public: - LoopTypeConverter(MLIRContext *context) { - // The order of type conversion is important: later ones are tried earlier. + PtrToUnrankedMemrefConverter() { addConversion([](Type type) { return type; }); - // addConversion([context](triton::PointerType ptrType) { - // SmallVector strides{1}; - // auto layout = - // StridedLayoutAttr::get(context, ShapedType::kDynamic, strides); - - // auto elemType = ptrType.getPointeeType(); - // auto memrefType = MemRefType::get({1}, elemType, layout); - // return memrefType; - // }); - - // A tensor of pointers can be passed in as scf.for's init-args, in such - // cases, we convert the type to a memref with dynamic offsets and - // strides. - addConversion( - [context](RankedTensorType tensorType) -> std::optional { - if (auto ptrType = llvm::dyn_cast( - tensorType.getElementType())) { - auto layout = StridedLayoutAttr::get( - context, ShapedType::kDynamic, - SmallVector(tensorType.getRank(), - ShapedType::kDynamic)); - Type elemType = ptrType.getPointeeType(); - return MemRefType::get(tensorType.getShape(), elemType, layout); - } - - return std::nullopt; - }); - - // Convert the current memref type to a memref type with dynamic offsets and - // strides through another reinterpret_cast with the same offsets. - // Canonicalization will simplify this sequence by removing the inital - // reinterpret_cast. - addTargetMaterialization([&](OpBuilder &builder, MemRefType memrefType, + addConversion([](triton::PointerType ptrType) { + return UnrankedMemRefType::get(ptrType.getPointeeType(), 0); + }); + addTargetMaterialization([&](OpBuilder &builder, + UnrankedMemRefType resultType, ValueRange inputs, Location loc) -> Value { - auto reinterpretCast = - inputs[0].getDefiningOp(); - if (!reinterpretCast) { - return builder - .create(loc, memrefType, inputs) - .getResult(0); - } - return builder.create( - loc, memrefType, inputs[0], reinterpretCast.getMixedOffsets()[0], - reinterpretCast.getMixedSizes(), reinterpretCast.getMixedStrides()); + return builder.create(loc, resultType, inputs) + .getResult(0); }); addSourceMaterialization([&](OpBuilder &builder, Type resultType, @@ -113,34 +72,18 @@ class LoopTypeConverter : public TypeConverter { } }; -class PtrToUnrankedMemrefConverter : public TypeConverter { -public: - PtrToUnrankedMemrefConverter() { - addConversion([](Type type) { return type; }); - addConversion([](triton::PointerType ptrType) { - return UnrankedMemRefType::get(ptrType.getPointeeType(), 0); - }); - addTargetMaterialization([&](OpBuilder &builder, - UnrankedMemRefType resultType, - ValueRange inputs, - Location loc) -> Value { - return builder.create(loc, resultType, inputs) - .getResult(0); - }); - } -}; - class StructuredToMemrefPass : public triton::impl::StructuredToMemrefBase { using StructuredToMemrefBase::StructuredToMemrefBase; public: void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); } void runOnOperation() override { @@ -165,11 +108,6 @@ class StructuredToMemrefPass triton::populateStructuredToMemrefConversionPatterns(patterns, typeConverter); - LoopTypeConverter loopTypeConverter(patterns.getContext()); - - mlir::scf::populateSCFStructuralTypeConversionsAndLegality( - loopTypeConverter, patterns, target); - if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { signalPassFailure(); }