4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -145,6 +145,10 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
/// ```
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);

/// Patterns for flattening multi-dimensional memref operations into
/// one-dimensional memref operations.
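///
/// For example (a sketch; `%m_flat` and `%idx` are illustrative names and the
/// linearized shape depends on the original layout), a 2-D access such as
///
/// ```
/// %v = vector.load %m[%i, %j] : memref<4x8xi4>, vector<8xi4>
/// ```
///
/// is rewritten to operate on a linearized 1-D memref:
///
/// ```
/// %v = vector.load %m_flat[%idx] : memref<32xi4>, vector<8xi4>
/// ```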
void populateFlattenVectorOpsOnMemrefPatterns(RewritePatternSet &patterns);
void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);

/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
@@ -383,6 +383,16 @@ void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns, bool disableAtomicRMW = false);

/// Populates patterns for both MemRef flattening and Vector narrow type
/// emulation.
///
/// Patterns for narrow-type emulation require "flattened" (i.e. 1-D) MemRefs,
/// so this composite populate* method can be used to apply narrow-type
/// emulation to Ops operating on MemRefs of rank > 1.
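///
/// A typical driver setup (a sketch mirroring the test pass; `ctx`, `op` and
/// `target` are assumed to be set up by the caller):
///
/// ```
/// arith::NarrowTypeEmulationConverter typeConverter(8);
/// memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
/// RewritePatternSet patterns(ctx);
/// memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
/// vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
///     typeConverter, patterns);
/// if (failed(applyPartialConversion(op, target, std::move(patterns))))
///   signalPassFailure();
/// ```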
void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);

/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
22 changes: 16 additions & 6 deletions mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -271,16 +271,26 @@ struct FlattenMemrefsPass

} // namespace

void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
MemRefRewritePattern<memref::StoreOp>,
MemRefRewritePattern<memref::AllocOp>,
MemRefRewritePattern<memref::AllocaOp>,
MemRefRewritePattern<vector::LoadOp>,
void memref::populateFlattenVectorOpsOnMemrefPatterns(
RewritePatternSet &patterns) {
patterns.insert<MemRefRewritePattern<vector::LoadOp>,
MemRefRewritePattern<vector::StoreOp>,
MemRefRewritePattern<vector::TransferReadOp>,
MemRefRewritePattern<vector::TransferWriteOp>,
MemRefRewritePattern<vector::MaskedLoadOp>,
MemRefRewritePattern<vector::MaskedStoreOp>>(
patterns.getContext());
}

void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
MemRefRewritePattern<memref::StoreOp>,
MemRefRewritePattern<memref::AllocOp>,
MemRefRewritePattern<memref::AllocaOp>>(
patterns.getContext());
}

void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
populateFlattenMemrefOpsPatterns(patterns);
populateFlattenVectorOpsOnMemrefPatterns(patterns);
}
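
// A minimal way to exercise these patterns standalone (a sketch; assumes a
// FuncOp-scoped pass and the greedy pattern rewrite driver):
//
//   RewritePatternSet patterns(&getContext());
//   memref::populateFlattenMemrefsPatterns(patterns);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();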
66 changes: 45 additions & 21 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -38,6 +38,8 @@
#include <cstdint>
#include <optional>

#include "mlir/Dialect/MemRef/Transforms/Transforms.h"

using namespace mlir;

#define DEBUG_TYPE "vector-narrow-type-emulation"
@@ -556,7 +558,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
if (op.getValueToStore().getType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
@@ -817,7 +818,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// ConvertVectorMaskedStore
//===----------------------------------------------------------------------===//

// TODO: Document-me
/// Converts `vector.maskedstore` operations on narrow element types to work
/// with wider, byte-aligned container types by adjusting the mask and using
/// bitcasting.
///
/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
/// (each `i8` container element holds two `i4` values) and storing with an
/// adjusted mask.
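///
/// A sketch of the rewrite (illustrative only; `%m`, `%i2` and `%pad` are
/// placeholder names and the emitted IR may differ in detail):
///
///   vector.maskedstore %m[%i], %mask, %val
///       : memref<8xi4>, vector<6xi1>, vector<6xi4>
///
/// becomes a read-modify-write on the converted i8 memref, roughly:
///
///   %new_mask = <compress %mask to vector<3xi1>>
///   %old = vector.maskedload %m[%i2], %new_mask, %pad
///       : memref<4xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
///   %old_i4 = vector.bitcast %old : vector<3xi8> to vector<6xi4>
///   %new = arith.select %mask, %val, %old_i4 : vector<6xi1>, vector<6xi4>
///   %new_i8 = vector.bitcast %new : vector<6xi4> to vector<3xi8>
///   vector.maskedstore %m[%i2], %new_mask, %new_i8
///       : memref<4xi8>, vector<3xi1>, vector<3xi8>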
struct ConvertVectorMaskedStore final
: OpConversionPattern<vector::MaskedStoreOp> {
using OpConversionPattern::OpConversionPattern;
@@ -826,10 +833,10 @@ struct ConvertVectorMaskedStore final
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
// Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
if (op.getValueToStore().getType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
return rewriter.notifyMatchFailure(
op, "Memref in vector.maskedstore op must be flattened beforehand.");

auto loc = op.getLoc();
auto containerElemTy =
@@ -931,18 +938,27 @@ struct ConvertVectorMaskedStore final
// ConvertVectorLoad
//===----------------------------------------------------------------------===//

// TODO: Document-me
/// Converts `vector.load` on narrow element types to work with
/// wider, byte-aligned container types by adjusting load sizes and using
/// bitcasting.
///
/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
/// container holds two `i4` values) and bitcasting back.
///
/// There are cases where the number of elements to load is not byte-aligned. In
/// those cases, loads are converted to byte-aligned, byte-sized loads and the
/// target vector is extracted from the loaded vector.
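///
/// A sketch of the non-byte-aligned case (illustrative only; the exact
/// extraction ops may differ):
///
///   %1 = vector.load %src[%c1] : memref<5xi4>, vector<3xi4>
///
/// is emulated by loading the two covering bytes and extracting a subvector:
///
///   %2 = vector.load %src[%c0] : memref<3xi8>, vector<2xi8>
///   %3 = vector.bitcast %2 : vector<2xi8> to vector<4xi4>
///   %4 = vector.extract_strided_slice %3
///       {offsets = [1], sizes = [3], strides = [1]}
///       : vector<4xi4> to vector<3xi4>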
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
// Prerequisite: memref in the vector.load op is flattened into 1-D.
if (op.getVectorType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
return rewriter.notifyMatchFailure(
op, "Memref in emulated vector ops must be flattened beforehand.");

auto loc = op.getLoc();
auto containerElemTy =
@@ -961,8 +977,6 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {

// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
// Here only the 1-D vector load is considered, and the N-D memref types
// should be linearized.
// For example, to emulate i4 to i8, the following op:
//
// %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
@@ -1037,18 +1051,22 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//

// TODO: Document-me
/// Converts `vector.maskedload` operations on narrow element types to work with
/// wider, byte-aligned container types by adjusting the mask and using
/// bitcasting.
///
/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
/// bitcasting, since each `i8` container element holds two `i4` values.
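///
/// A sketch of the rewrite (illustrative only; `%m`, `%i2` and `%new_mask`
/// are placeholder names):
///
///   %r = vector.maskedload %m[%i], %mask, %pass
///       : memref<6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
///
/// becomes, on the converted i8 memref, roughly:
///
///   %new_mask = <compress %mask to vector<3xi1>>
///   %pass_i8 = vector.bitcast %pass : vector<6xi4> to vector<3xi8>
///   %load = vector.maskedload %m[%i2], %new_mask, %pass_i8
///       : memref<3xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
///   %load_i4 = vector.bitcast %load : vector<3xi8> to vector<6xi4>
///   %r = arith.select %mask, %load_i4, %pass : vector<6xi1>, vector<6xi4>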
struct ConvertVectorMaskedLoad final
: OpConversionPattern<vector::MaskedLoadOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// See #115653
if (op.getVectorType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
return rewriter.notifyMatchFailure(
op, "Memref in emulated vector ops must be flattened beforehand.");

auto loc = op.getLoc();

@@ -1229,7 +1247,6 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,

int elemsPerMultiByte = multiByteBits / subByteBits;

// TODO: This is a bit too restrictive for vectors rank > 1.
return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
}

@@ -1246,10 +1263,11 @@ struct ConvertVectorTransferRead final
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
// Prerequisite: memref in the vector.transfer_read op is flattened into
// 1-D.
if (op.getVectorType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
return rewriter.notifyMatchFailure(
op, "Memref in emulated vector ops must be flattened beforehand.");

auto loc = op.getLoc();
auto containerElemTy =
@@ -2227,7 +2245,6 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns, bool disableAtomicRMW) {

// Populate `vector.*` conversion patterns.
// TODO: #119553 support atomicity
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
@@ -2266,3 +2283,10 @@ void vector::populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
}

void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns) {
memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
}
Contributor

populateFlattenVectorMemrefPatterns includes patterns for vector.store, vector.transfer_read and vector.transfer_write. Can you remind me - are these supported? If not, could you add a TODO here?

Also, could you add a high-level comment specifying what combination of patterns is tested and that we are merely verifying that narrow-type-emulation works for rank > 1 memrefs? Otherwise the lack of more thorough test lines feels a bit ad-hoc.

Member Author

I added new tests for everything except vector.transfer_write, and added a TODO for that.

I also updated the comments.

@@ -0,0 +1,64 @@
// RUN: mlir-opt --test-memref-flatten-and-vector-narrow-type-emulation --split-input-file %s | FileCheck %s

// This test verifies that narrow-type-emulation works correctly for
// rank > 1 memrefs by combining memref flattening with vector narrow type
// emulation patterns.
//
// The patterns tested here demonstrate the composition of two transformations:
// memref flattening for vector ops and vector narrow-type emulation.
//
// TODO: Support `vector.transfer_write` operation.

func.func @vector_load_2d_i4(%arg0: index, %arg1: index) -> vector<8xi4> {
%0 = memref.alloc() : memref<4x8xi4>
%1 = vector.load %0[%arg0, %arg1] : memref<4x8xi4>, vector<8xi4>
return %1 : vector<8xi4>
}
// CHECK: func @vector_load_2d_i4
// CHECK: vector.load {{.*}} memref<16xi8>

// -----

func.func @vector_maskedload_2d_i4(%arg0: index, %arg1: index, %passthru: vector<8xi4>) -> vector<8xi4> {
%0 = memref.alloc() : memref<4x8xi4>
%mask = vector.constant_mask [6] : vector<8xi1>
%1 = vector.maskedload %0[%arg0, %arg1], %mask, %passthru :
memref<4x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
return %1 : vector<8xi4>
}
// CHECK: func @vector_maskedload_2d_i4(
// CHECK: vector.maskedload {{.*}} memref<16xi8>

// -----

func.func @vector_maskedstore_2d_i4(%arg0: index, %arg1: index, %value: vector<8xi4>) {
%0 = memref.alloc() : memref<4x8xi4>
%mask = vector.constant_mask [5] : vector<8xi1>
vector.maskedstore %0[%arg0, %arg1], %mask, %value :
memref<4x8xi4>, vector<8xi1>, vector<8xi4>
return
}
// CHECK: func @vector_maskedstore_2d_i4(
// CHECK: vector.maskedstore {{.*}} memref<16xi8>

// -----

func.func @vector_store_2d_i4(%arg0: index, %arg1: index, %value: vector<8xi4>) {
%0 = memref.alloc() : memref<4x8xi4>
vector.store %value, %0[%arg0, %arg1] : memref<4x8xi4>, vector<8xi4>
return
}
// CHECK: func @vector_store_2d_i4(
// CHECK: vector.store {{.*}} memref<16xi8>

// -----

func.func @vector_transfer_read_2d_i4(%arg0: index, %arg1: index, %padding: i4) -> vector<8xi4> {
%0 = memref.alloc() : memref<4x8xi4>
%1 = vector.transfer_read %0[%arg0, %arg1], %padding {in_bounds = [true]} : memref<4x8xi4>, vector<8xi4>
return %1 : vector<8xi4>
}
// CHECK: func @vector_transfer_read_2d_i4(
// CHECK-SAME: %{{.*}}: index, %{{.*}}: index, %[[PADDING_I4:.*]]: i4)
// CHECK: %[[PADDING_I8:.*]] = arith.extui %[[PADDING_I4]] : i4 to i8
// CHECK: vector.transfer_read {{.*}}, %[[PADDING_I8]] : memref<16xi8>, vector<4xi8>
61 changes: 61 additions & 0 deletions mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

@@ -126,10 +127,70 @@ struct TestEmulateNarrowTypePass
"normal sequence"),
llvm::cl::init(false)};
};

struct TestMemRefFlattenAndVectorNarrowTypeEmulationPass
: public PassWrapper<TestMemRefFlattenAndVectorNarrowTypeEmulationPass,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestMemRefFlattenAndVectorNarrowTypeEmulationPass)

TestMemRefFlattenAndVectorNarrowTypeEmulationPass() = default;
TestMemRefFlattenAndVectorNarrowTypeEmulationPass(
const TestMemRefFlattenAndVectorNarrowTypeEmulationPass &pass)
: PassWrapper(pass) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
vector::VectorDialect, affine::AffineDialect>();
}

StringRef getArgument() const final {
return "test-memref-flatten-and-vector-narrow-type-emulation";
}

StringRef getDescription() const final {
return "Test MemRef flattening and vector narrow type emulation patterns";
}

void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *ctx = &getContext();

// Create a type converter for narrow type emulation (8-bit container type).
arith::NarrowTypeEmulationConverter typeConverter(8);

// Add type conversions for memrefs with sub-byte (e.g. i4) element types.
memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

ConversionTarget target(*ctx);
target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
});
auto opLegalCallback = [&typeConverter](Operation *op) {
return typeConverter.isLegal(op);
};
target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
target.addDynamicallyLegalDialect<
arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
affine::AffineDialect>(opLegalCallback);

RewritePatternSet patterns(ctx);

memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
typeConverter, patterns);

// Apply the partial conversion.
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace

namespace mlir::test {
void registerTestEmulateNarrowTypePass() {
PassRegistration<TestEmulateNarrowTypePass>();
PassRegistration<TestMemRefFlattenAndVectorNarrowTypeEmulationPass>();
}
} // namespace mlir::test