4 changes: 2 additions & 2 deletions include/triton-shared/Conversion/StructuredToMemref/Passes.td
@@ -1,5 +1,5 @@
#ifndef TRITON_ARITH_TO_LINALG_CONVERSION_PASSES
#define TRITON_ARITH_TO_LINALG_CONVERSION_PASSES
#ifndef STRUCTURED_TO_MEMREF_CONVERSION_PASSES
#define STRUCTURED_TO_MEMREF_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

160 changes: 34 additions & 126 deletions lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
@@ -5,9 +5,12 @@
//
//===----------------------------------------------------------------------===//

#include "triton/Dialect/Triton/IR/Types.h"

#include "triton-shared/Analysis/OpFoldResultUtils.h"
#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -18,22 +21,17 @@
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton-shared/Analysis/OpFoldResultUtils.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR//MemRef.h"
#include "triton/Dialect/Triton/IR/Types.h"

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"

#include <algorithm>
#include <cassert>
@@ -373,28 +371,9 @@ struct MakeTensorPtrConverter
op, staticTargetOffset.value_or(ShapedType::kDynamic), staticStrides,
resultShape);

// The base ptr, which is from one of the args, would have already been
// converted to memref<*> at this point, so get the base from adaptor.
//
// For block pointers, the base could come from a sequence of `tt.addptr`,
// which at this point has already been lowered to a sequence of
// `memref.reinterpret_cast` ops. The offset in such cases are dynamic.
// (see test/Conversion/StructuredToMemref/block_ptr_complex_offset.mlir)
//
// For non-block pointer cases, the base is the reinterpret_cast of a
// function argument. Assert that the offset is a constant 0 in such cases.
auto ptr = adaptor.getBase();
if (auto reinterpretCast = ptr.getDefiningOp<memref::ReinterpretCastOp>()) {
auto offset = reinterpretCast.getMixedOffsets()[0];
auto intAttr = getIntAttr(offset);
assert(isBlockPtr || (intAttr.has_value() && intAttr.value() == 0));
targetOffset = addOFRs(targetOffset, reinterpretCast.getMixedOffsets()[0],
op->getLoc(), rewriter);
}

auto castOp = rewriter.create<memref::ReinterpretCastOp>(
op.getLoc(), resultType, ptr, targetOffset, op.getMixedSizes(),
mixedStrides);
op.getLoc(), resultType, adaptor.getBase(), targetOffset,
op.getMixedSizes(), mixedStrides);

rewriter.replaceOp(op, castOp);

Expand All @@ -421,6 +400,10 @@ struct MakeTensorPtrConverter
}

public:
MakeTensorPtrConverter(const TypeConverter &typeConverter,
MLIRContext *context)
: OpConversionPattern<tts::MakeTensorPtrOp>(typeConverter, context) {}

LogicalResult
matchAndRewrite(tts::MakeTensorPtrOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -594,8 +577,13 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
// No mask
assert(!other && "other value used in non-masked load");

if (auto unrealizedCast = ptr.getDefiningOp<UnrealizedConversionCastOp>()) {
auto ptrDefiningOp = ptr.getDefiningOp();
if (ptrDefiningOp->hasAttr(WRAP_SIDE_BY_SIDE) ||
ptrDefiningOp->hasAttr(WRAP_STACKED)) {

auto unrealizedCast = cast<UnrealizedConversionCastOp>(ptrDefiningOp);
auto memrefs = unrealizedCast.getOperands();
assert(memrefs.size() == 2);
auto block1 = memrefs[0];
auto block2 = memrefs[1];

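Note: the wrap-around case is now recognized by an attribute on the pointer's defining op rather than by matching any UnrealizedConversionCastOp. The following is a minimal, hypothetical producer-side sketch (not part of this diff) of how an earlier pattern could bundle the two contiguous blocks and tag the cast; it assumes WRAP_SIDE_BY_SIDE / WRAP_STACKED are the attribute keys checked above, and block1/block2/ptrLikeResultType are illustrative names.

// Hypothetical producer sketch: package the two memref blocks of a
// side-by-side (wrapping) access into one unrealized cast and tag it, so the
// LoadConverter above can detect the wrap-around case via hasAttr.
SmallVector<Value> blocks{block1, block2};        // two contiguous memref blocks
auto packed = rewriter.create<UnrealizedConversionCastOp>(
    loc, TypeRange{ptrLikeResultType}, blocks);   // single SSA value for the ptr
packed->setAttr(WRAP_SIDE_BY_SIDE, rewriter.getUnitAttr());
Value taggedPtr = packed.getResult(0);            // later consumed by the tts.load lowering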
@@ -664,9 +652,14 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
});
}

if (auto unrealizedCast = ptr.getDefiningOp<UnrealizedConversionCastOp>()) {
auto ptrDefiningOp = ptr.getDefiningOp();
if (ptrDefiningOp->hasAttr(WRAP_SIDE_BY_SIDE) ||
ptrDefiningOp->hasAttr(WRAP_STACKED)) {

auto unrealizedCast = cast<UnrealizedConversionCastOp>(ptrDefiningOp);

auto memrefs = unrealizedCast.getOperands();
assert(memrefs.size() == 2);
auto block1 = memrefs[0];
auto block2 = memrefs[1];

@@ -700,6 +693,9 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
}

public:
LoadConverter(const TypeConverter &typeConverter, MLIRContext *context)
: OpConversionPattern<tts::LoadOp>(typeConverter, context) {}

LogicalResult
matchAndRewrite(tts::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -730,6 +726,9 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
}

public:
StoreConverter(const TypeConverter &typeConverter, MLIRContext *context)
: OpConversionPattern<tts::StoreOp>(typeConverter, context) {}

LogicalResult
matchAndRewrite(tts::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -759,101 +758,10 @@ struct StoreConverter : public OpConversionPattern<tts::StoreOp> {
}
};

struct ScalarLoadConverter : public OpConversionPattern<triton::LoadOp> {
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!op.getType().isIntOrIndexOrFloat()) {
return failure();
}

auto loc = op->getLoc();
auto memrefPtr = adaptor.getPtr();
auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext());
auto loadOp = rewriter.create<affine::AffineLoadOp>(loc, memrefPtr, zeroMap,
std::nullopt);
rewriter.replaceOp(op, loadOp.getResult());

return success();
}
};

struct ScalarStoreConverter : public OpConversionPattern<triton::StoreOp> {
private:
using OpConversionPattern<triton::StoreOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (!op.getValue().getType().isIntOrIndexOrFloat()) {
return failure();
}

auto loc = op->getLoc();
auto memrefPtr = adaptor.getPtr();
auto val = op.getValue();
auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext());

rewriter.create<affine::AffineStoreOp>(loc, val, memrefPtr, zeroMap,
std::nullopt);
rewriter.eraseOp(op);

return success();
}
};

struct UnrealizedCastConverter
: public OpConversionPattern<UnrealizedConversionCastOp> {
private:
using OpConversionPattern<UnrealizedConversionCastOp>::OpConversionPattern;

public:
UnrealizedCastConverter(TypeConverter &typeConverter, MLIRContext *context)
: OpConversionPattern<UnrealizedConversionCastOp>(typeConverter,
context) {}

LogicalResult
matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resType = op->getResultTypes()[0];
auto input = op.getInputs()[0];
auto inputType = input.getType();

if (!isa<triton::PointerType>(resType) ||
!isa<MemRefType, UnrankedMemRefType>(inputType)) {
return failure();
}

if (auto reinterpretCast =
input.getDefiningOp<memref::ReinterpretCastOp>()) {
rewriter.replaceOp(op, reinterpretCast);
} else {
auto ptrType = cast<triton::PointerType>(resType);
auto memrefType =
cast<MemRefType>(getTypeConverter()->convertType(ptrType));

auto cast = rewriter.create<memref::ReinterpretCastOp>(
op->getLoc(), memrefType, op.getInputs()[0], 0 /*offset*/,
SmallVector<int64_t>{1} /*sizes*/,
SmallVector<int64_t>{1} /*strides*/);

rewriter.replaceOp(op, cast);
}

return success();
}
};

} // namespace

void mlir::triton::populateStructuredToMemrefConversionPatterns(
RewritePatternSet &patterns, TypeConverter &typeConverter) {
patterns.add<UnrealizedCastConverter>(typeConverter, patterns.getContext());
patterns.add<MakeTensorPtrConverter, LoadConverter, StoreConverter,
ScalarLoadConverter, ScalarStoreConverter>(
patterns.getContext());
patterns.add<MakeTensorPtrConverter>(typeConverter, patterns.getContext());
patterns.add<LoadConverter, StoreConverter>(patterns.getContext());
}
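
With the scalar and unrealized-cast patterns dropped from this entry point, callers pass the type converter only to the patterns that need it. Below is a minimal hedged sketch of how a driver pass might wire this up; the pass class name, conversion-target setup, and type-converter contents are assumptions for illustration, not taken from this diff.

// Hedged driver sketch: standard MLIR dialect-conversion boilerplate around
// populateStructuredToMemrefConversionPatterns. Details are illustrative.
void ExampleStructuredToMemrefPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  ConversionTarget target(*ctx);
  // ... mark tts::MakeTensorPtrOp / tts::LoadOp / tts::StoreOp illegal and the
  //     memref/affine/scf dialects legal (assumed configuration) ...
  TypeConverter typeConverter;
  // ... add conversions mapping !tt.ptr types to memref types (assumed) ...
  RewritePatternSet patterns(ctx);
  mlir::triton::populateStructuredToMemrefConversionPatterns(patterns, typeConverter);
  if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}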