4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
Expand Up @@ -145,6 +145,10 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
/// ```
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);

/// Patterns for flattening multi-dimensional memref operations into
/// one-dimensional memref operations.
///
/// Flattens vector.load/store, vector.transfer_read/write and
/// vector.maskedload/maskedstore ops that operate on multi-dimensional
/// memrefs.
void populateFlattenVectorOpsOnMemrefPatterns(RewritePatternSet &patterns);
/// Flattens memref.load/store as well as memref.alloc/alloca of
/// multi-dimensional memrefs.
void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
/// Populates both of the pattern sets above.
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);

/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
Expand Down
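For context (not part of this patch), here is a minimal sketch of how these flattening pattern sets might be driven with the greedy rewriter. It assumes `using namespace mlir;`, the usual headers, and that the greedy driver entry point is `applyPatternsGreedily` (older MLIR releases spell it `applyPatternsAndFoldGreedily`); the helper name is made up for illustration.

#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: flatten all memref and vector-on-memref accesses in a
// function using the new populate* entry points.
static LogicalResult flattenAllMemrefAccesses(func::FuncOp func) {
  RewritePatternSet patterns(func.getContext());
  memref::populateFlattenMemrefOpsPatterns(patterns);
  memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
  // Equivalent to the single call memref::populateFlattenMemrefsPatterns(patterns).
  return applyPatternsGreedily(func, std::move(patterns));
}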
Expand Up @@ -383,6 +383,16 @@ void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns, bool disableAtomicRMW = false);

/// Populates patterns for both MemRef flattening and vector narrow type
/// emulation.
///
/// The narrow-type-emulation patterns require "flattened" (i.e. rank-1)
/// MemRefs, so this composite populate* method can be used to apply
/// narrow-type emulation to ops operating on MemRefs of rank > 1.
void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);

/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
Expand Down
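A hedged sketch of how the composite entry point slots into a narrow-type-emulation conversion follows; it loosely mirrors the test pass added later in this PR, assumes `using namespace mlir;`, and elides the conversion-target legality and function-signature handling.

// Sketch only: emulate sub-byte element types with byte-sized (i8) containers,
// flattening rank > 1 memref accesses so the 1-D-only emulation patterns apply.
void runNarrowTypeEmulation(func::FuncOp func, ConversionTarget &target) {
  MLIRContext *ctx = func.getContext();
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  RewritePatternSet patterns(ctx);
  // Needed to handle memref.alloc and function boundaries.
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
  // Memref flattening + vector narrow-type emulation, composed.
  vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
      typeConverter, patterns);

  // Legality of `target` is assumed to be configured from `typeConverter`.
  (void)applyPartialConversion(func, target, std::move(patterns));
}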
22 changes: 16 additions & 6 deletions mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
Expand Up @@ -271,16 +271,26 @@ struct FlattenMemrefsPass

} // namespace

void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
MemRefRewritePattern<memref::StoreOp>,
MemRefRewritePattern<memref::AllocOp>,
MemRefRewritePattern<memref::AllocaOp>,
MemRefRewritePattern<vector::LoadOp>,
void memref::populateFlattenVectorOpsOnMemrefPatterns(
RewritePatternSet &patterns) {
patterns.insert<MemRefRewritePattern<vector::LoadOp>,
MemRefRewritePattern<vector::StoreOp>,
MemRefRewritePattern<vector::TransferReadOp>,
MemRefRewritePattern<vector::TransferWriteOp>,
MemRefRewritePattern<vector::MaskedLoadOp>,
MemRefRewritePattern<vector::MaskedStoreOp>>(
patterns.getContext());
}

void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
MemRefRewritePattern<memref::StoreOp>,
MemRefRewritePattern<memref::AllocOp>,
MemRefRewritePattern<memref::AllocaOp>>(
patterns.getContext());
}

void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
populateFlattenMemrefOpsPatterns(patterns);
populateFlattenVectorOpsOnMemrefPatterns(patterns);
}
66 changes: 45 additions & 21 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Expand Up @@ -38,6 +38,8 @@
#include <cstdint>
#include <optional>

#include "mlir/Dialect/MemRef/Transforms/Transforms.h"

using namespace mlir;

#define DEBUG_TYPE "vector-narrow-type-emulation"
Expand Down Expand Up @@ -556,7 +558,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
if (op.getValueToStore().getType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
Expand Down Expand Up @@ -817,7 +818,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// ConvertVectorMaskedStore
//===----------------------------------------------------------------------===//

// TODO: Document-me
/// Converts `vector.maskedstore` operations on narrow element types to work
/// with wider, byte-aligned container types by adjusting the mask and using
/// bitcasting.
///
/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
/// (each `i8` container element holds two `i4` values) and storing with an
/// adjusted mask.
struct ConvertVectorMaskedStore final
: OpConversionPattern<vector::MaskedStoreOp> {
using OpConversionPattern::OpConversionPattern;
Expand All @@ -826,10 +833,10 @@ struct ConvertVectorMaskedStore final
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
// Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
if (op.getValueToStore().getType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
return rewriter.notifyMatchFailure(
op, "Memref in vector.maskedstore op must be flattened beforehand.");

auto loc = op.getLoc();
auto containerElemTy =
Expand Down Expand Up @@ -931,18 +938,27 @@ struct ConvertVectorMaskedStore final
// ConvertVectorLoad
//===----------------------------------------------------------------------===//

// TODO: Document-me
/// Converts `vector.load` on narrow element types to work with
/// wider, byte-aligned container types by adjusting load sizes and using
/// bitcasting.
///
/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
/// container holds two `i4` values) and bitcasting back.
///
/// There are cases where the number of elements to load is not byte-aligned. In
/// those cases, loads are converted to byte-aligned, byte-sized loads and the
/// target vector is extracted from the loaded vector.
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
// Prerequisite: memref in the vector.load op is flattened into 1-D.
if (op.getVectorType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
return rewriter.notifyMatchFailure(
op, "Memref in emulated vector ops must be flattened beforehand.");

auto loc = op.getLoc();
auto containerElemTy =
Expand All @@ -961,8 +977,6 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {

// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
// Here only the 1-D vector load is considered, and the N-D memref types
// should be linearized.
// For example, to emulate i4 to i8, the following op:
//
// %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
Expand Down Expand Up @@ -1037,18 +1051,22 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//

// TODO: Document-me
/// Converts `vector.maskedload` operations on narrow element types to work with
/// wider, byte-aligned container types by adjusting the mask and using
/// bitcasting.
///
/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
/// bitcasting, since each `i8` container element holds two `i4` values.
struct ConvertVectorMaskedLoad final
: OpConversionPattern<vector::MaskedLoadOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// See #115653
if (op.getVectorType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
return rewriter.notifyMatchFailure(
op, "Memref in emulated vector ops must be flattened beforehand.");

auto loc = op.getLoc();

Expand Down Expand Up @@ -1229,7 +1247,6 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,

int elemsPerMultiByte = multiByteBits / subByteBits;

// TODO: This is a bit too restrictive for vectors rank > 1.
return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
}

Expand All @@ -1246,10 +1263,11 @@ struct ConvertVectorTransferRead final
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
// Prerequisites: memref in the vector.transfer_read op is flattened into
// 1-D.
if (op.getVectorType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
return rewriter.notifyMatchFailure(
op, "Memref in emulated vector ops must be flattened beforehand.");

auto loc = op.getLoc();
auto containerElemTy =
Expand Down Expand Up @@ -2227,7 +2245,6 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns, bool disableAtomicRMW) {

// Populate `vector.*` conversion patterns.
// TODO: #119553 support atomicity
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
Expand Down Expand Up @@ -2266,3 +2283,10 @@ void vector::populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
}

void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns) {
memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
}
Contributor
populateFlattenVectorMemrefPatterns includes patterns for vector.store, vector.transfer_read and vector.transfer_write. Can you remind me - are these supported? If not, could you add a TODO here?

Also, could you add a high-level comment specifying what combination of patterns is tested and that we are merely verifying that narrow-type-emulation works for rank > 1 memrefs? Otherwise the lack of more thorough test lines feels a bit ad-hoc.

Member Author
I added new tests for all of these except vector.transfer_write, and added a TODO for that one.

I also updated the comments.

@@ -0,0 +1,64 @@
// RUN: mlir-opt --test-memref-flatten-and-vector-narrow-type-emulation --split-input-file %s | FileCheck %s

// This test verifies that narrow-type-emulation works correctly for
// rank > 1 memrefs by combining memref flattening with vector narrow type
// emulation patterns.
//
// The patterns tested here demonstrate the composition of two transformations:
// memref flattening for vector ops and vector-op narrow-type emulation.
//
// TODO: Support `vector.transfer_write` operation.

func.func @vector_load_2d_i4(%arg0: index) -> vector<8xi4> {
%0 = memref.alloc() : memref<4x8xi4>
%1 = vector.load %0[%arg0, %arg0] : memref<4x8xi4>, vector<8xi4>
return %1 : vector<8xi4>
}
// CHECK-LABEL: func @vector_load_2d_i4
// CHECK: vector.load {{.*}} memref<16xi8>

// -----

func.func @vector_maskedload_2d_i4(%arg0: index, %passthru: vector<8xi4>) -> vector<8xi4> {
%0 = memref.alloc() : memref<4x8xi4>
%mask = vector.constant_mask [6] : vector<8xi1>
%1 = vector.maskedload %0[%arg0, %arg0], %mask, %passthru :
memref<4x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
return %1 : vector<8xi4>
}
// CHECK-LABEL: func @vector_maskedload_2d_i4(
// CHECK: vector.maskedload {{.*}} memref<16xi8>

// -----

func.func @vector_maskedstore_2d_i4(%arg0: index, %value: vector<8xi4>) {
%0 = memref.alloc() : memref<4x8xi4>
%mask = vector.constant_mask [5] : vector<8xi1>
vector.maskedstore %0[%arg0, %arg0], %mask, %value :
memref<4x8xi4>, vector<8xi1>, vector<8xi4>
return
}
// CHECK-LABEL: func @vector_maskedstore_2d_i4(
// CHECK: vector.maskedstore {{.*}} memref<16xi8>

// -----

func.func @vector_store_2d_i4(%arg0: index, %value: vector<8xi4>) {
%0 = memref.alloc() : memref<4x8xi4>
vector.store %value, %0[%arg0, %arg0] : memref<4x8xi4>, vector<8xi4>
return
}
// CHECK-LABEL: func @vector_store_2d_i4(
// CHECK: vector.store {{.*}} memref<16xi8>

// -----

func.func @vector_transfer_read_2d_i4(%arg0: index, %padding: i4) -> vector<8xi4> {
%0 = memref.alloc() : memref<4x8xi4>
%1 = vector.transfer_read %0[%arg0, %arg0], %padding {in_bounds = [true]} : memref<4x8xi4>, vector<8xi4>
return %1 : vector<8xi4>
}
// CHECK-LABEL: func @vector_transfer_read_2d_i4(
// CHECK-SAME: %{{.*}}: index, %[[PADDING_I4:.*]]: i4)
// CHECK: %[[PADDING_I8:.*]] = arith.extui %[[PADDING_I4]] : i4 to i8
// CHECK: vector.transfer_read {{.*}}, %[[PADDING_I8]] : memref<16xi8>, vector<4xi8>
64 changes: 64 additions & 0 deletions mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

Expand Down Expand Up @@ -126,10 +127,73 @@ struct TestEmulateNarrowTypePass
"normal sequence"),
llvm::cl::init(false)};
};

struct TestMemRefFlattenAndVectorNarrowTypeEmulationPass
: public PassWrapper<TestMemRefFlattenAndVectorNarrowTypeEmulationPass,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestMemRefFlattenAndVectorNarrowTypeEmulationPass)

TestMemRefFlattenAndVectorNarrowTypeEmulationPass() = default;
TestMemRefFlattenAndVectorNarrowTypeEmulationPass(
const TestMemRefFlattenAndVectorNarrowTypeEmulationPass &pass)
: PassWrapper(pass) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
vector::VectorDialect, affine::AffineDialect>();
}

StringRef getArgument() const final {
return "test-memref-flatten-and-vector-narrow-type-emulation";
}

StringRef getDescription() const final {
return "Test MemRef flattening and vector narrow type emulation patterns";
}

void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *ctx = &getContext();

// Create a type converter for narrow type emulation (8-bit)
arith::NarrowTypeEmulationConverter typeConverter(8);

// Add conversions for memref types with i4 elements
memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

ConversionTarget target(*ctx);
target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
});
auto opLegalCallback = [&typeConverter](Operation *op) {
return typeConverter.isLegal(op);
};
target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
target.addDynamicallyLegalDialect<
arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
affine::AffineDialect>(opLegalCallback);

RewritePatternSet patterns(ctx);

// This is needed to emulate `memref.alloc` ops and function boundaries.
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);

vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
typeConverter, patterns);

// Apply partial conversion
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace

namespace mlir::test {
void registerTestEmulateNarrowTypePass() {
PassRegistration<TestEmulateNarrowTypePass>();
PassRegistration<TestMemRefFlattenAndVectorNarrowTypeEmulationPass>();
}
} // namespace mlir::test