Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 16 additions & 78 deletions lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"

Expand All @@ -24,12 +25,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"

#include <optional>

#define DEBUG_TYPE "structured-to-memref"

Expand All @@ -45,56 +41,19 @@ namespace triton {

namespace {

class LoopTypeConverter : public TypeConverter {
class PtrToUnrankedMemrefConverter : public TypeConverter {
public:
LoopTypeConverter(MLIRContext *context) {
// The order of type conversion is important: later ones are tried earlier.
PtrToUnrankedMemrefConverter() {
addConversion([](Type type) { return type; });
// addConversion([context](triton::PointerType ptrType) {
// SmallVector<int64_t> strides{1};
// auto layout =
// StridedLayoutAttr::get(context, ShapedType::kDynamic, strides);

// auto elemType = ptrType.getPointeeType();
// auto memrefType = MemRefType::get({1}, elemType, layout);
// return memrefType;
// });

// A tensor of pointers can be passed in as scf.for's init-args, in such
// cases, we convert the type to a memref with dynamic offsets and
// strides.
addConversion(
[context](RankedTensorType tensorType) -> std::optional<MemRefType> {
if (auto ptrType = llvm::dyn_cast<triton::PointerType>(
tensorType.getElementType())) {
auto layout = StridedLayoutAttr::get(
context, ShapedType::kDynamic,
SmallVector<int64_t>(tensorType.getRank(),
ShapedType::kDynamic));
Type elemType = ptrType.getPointeeType();
return MemRefType::get(tensorType.getShape(), elemType, layout);
}

return std::nullopt;
});

// Convert the current memref type to a memref type with dynamic offsets and
// strides through another reinterpret_cast with the same offsets.
// Canonicalization will simplify this sequence by removing the inital
// reinterpret_cast.
addTargetMaterialization([&](OpBuilder &builder, MemRefType memrefType,
addConversion([](triton::PointerType ptrType) {
return UnrankedMemRefType::get(ptrType.getPointeeType(), 0);
});
addTargetMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs,
Location loc) -> Value {
auto reinterpretCast =
inputs[0].getDefiningOp<memref::ReinterpretCastOp>();
if (!reinterpretCast) {
return builder
.create<UnrealizedConversionCastOp>(loc, memrefType, inputs)
.getResult(0);
}
return builder.create<memref::ReinterpretCastOp>(
loc, memrefType, inputs[0], reinterpretCast.getMixedOffsets()[0],
reinterpretCast.getMixedSizes(), reinterpretCast.getMixedStrides());
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});

addSourceMaterialization([&](OpBuilder &builder, Type resultType,
Expand All @@ -113,34 +72,18 @@ class LoopTypeConverter : public TypeConverter {
}
};

class PtrToUnrankedMemrefConverter : public TypeConverter {
public:
PtrToUnrankedMemrefConverter() {
addConversion([](Type type) { return type; });
addConversion([](triton::PointerType ptrType) {
return UnrankedMemRefType::get(ptrType.getPointeeType(), 0);
});
addTargetMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs,
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
}
};

class StructuredToMemrefPass
: public triton::impl::StructuredToMemrefBase<StructuredToMemrefPass> {
using StructuredToMemrefBase<StructuredToMemrefPass>::StructuredToMemrefBase;

public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect, arith::ArithDialect, math::MathDialect,
linalg::LinalgDialect, affine::AffineDialect,
scf::SCFDialect, tensor::TensorDialect,
bufferization::BufferizationDialect, triton::TritonDialect,
ttx::TritonTilingExtDialect, memref::MemRefDialect>();
registry
.insert<tptr::TPtrDialect, func::FuncDialect, arith::ArithDialect,
math::MathDialect, linalg::LinalgDialect, affine::AffineDialect,
scf::SCFDialect, tensor::TensorDialect,
bufferization::BufferizationDialect, triton::TritonDialect,
ttx::TritonTilingExtDialect, memref::MemRefDialect>();
}

void runOnOperation() override {
Expand All @@ -165,11 +108,6 @@ class StructuredToMemrefPass
triton::populateStructuredToMemrefConversionPatterns(patterns,
typeConverter);

LoopTypeConverter loopTypeConverter(patterns.getContext());

mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
loopTypeConverter, patterns, target);

if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
}
Expand Down
Loading