Commit 4a09409

[MLIR] Make 1-D memref flattening a prerequisite for vector narrow type emulation (#157771)
Addresses #115653. We already have utilities to flatten memrefs into 1-D. This change makes memref flattening a prerequisite for vector narrow type emulation, ensuring that the emulation patterns only need to handle 1-D scenarios.
1 parent 8b3c91c commit 4a09409
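
For context, the sketch below shows how a pass might wire the new composite entry point into a narrow-type-emulation pipeline. It mirrors the test pass added by this commit; the helper name runNarrowTypeEmulation and the omitted legality setup are illustrative, not part of the change.

    // Minimal sketch (assumes the usual MLIR pass headers plus
    // MemRef/Transforms/Transforms.h and Vector/Transforms/VectorRewritePatterns.h).
    static LogicalResult runNarrowTypeEmulation(Operation *op, MLIRContext *ctx) {
      // Emulate sub-byte element types (e.g. i4) with an 8-bit container type.
      arith::NarrowTypeEmulationConverter typeConverter(8);
      memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

      RewritePatternSet patterns(ctx);
      // Flatten rank > 1 memrefs used by vector ops, then emulate narrow types.
      vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
          typeConverter, patterns);

      ConversionTarget target(*ctx); // legality callbacks elided for brevity
      return applyPartialConversion(op, target, std::move(patterns));
    }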

File tree: 6 files changed (+203, -27 lines)

mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
@@ -145,6 +145,10 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
 /// ```
 void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
+/// Patterns for flattening multi-dimensional memref operations into
+/// one-dimensional memref operations.
+void populateFlattenVectorOpsOnMemrefPatterns(RewritePatternSet &patterns);
+void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
 void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
 
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
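
The header now exposes the vector-op and memref-op flattening patterns separately. A hedged sketch of using just the vector-op subset with the greedy rewrite driver (the helper name is illustrative; applyPatternsGreedily is the current driver entry point, spelled applyPatternsAndFoldGreedily on older trees):

    // Flatten only the vector ops (load/store/transfer_read/transfer_write/
    // maskedload/maskedstore) that operate on rank > 1 memrefs; memref.load,
    // memref.store, and allocations are left untouched.
    static LogicalResult flattenVectorOpsOnMemrefs(Operation *op) {
      RewritePatternSet patterns(op->getContext());
      memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
      return applyPatternsGreedily(op, std::move(patterns));
    }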

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 10 additions & 0 deletions
@@ -383,6 +383,16 @@ void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns, bool disableAtomicRMW = false);
 
+/// Populates patterns for both MemRef flattening and Vector narrow type
+/// emulation.
+///
+/// Patterns for narrow-type-emulation require "flattened" MemRef(s), so this
+/// composite populate* method can be used for narrow-type-emulation for Ops
+/// operating on MemRef(s) that are rank > 1.
+void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns);
+
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
 /// Warning: these patterns currently only work for little endian targets.
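
As the new comment notes, the plain emulation entry point assumes 1-D memrefs, while the composite one also registers the flattening patterns. A small illustrative helper (not part of this commit) that makes the choice explicit:

    // Illustrative only: choose the emulation entry point depending on whether
    // the input IR is already known to use flattened (1-D) memrefs.
    static void populateEmulation(arith::NarrowTypeEmulationConverter &converter,
                                  RewritePatternSet &patterns,
                                  bool memrefsAlreadyFlattened) {
      if (memrefsAlreadyFlattened)
        vector::populateVectorNarrowTypeEmulationPatterns(converter, patterns);
      else
        vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
            converter, patterns);
    }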

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 16 additions & 6 deletions
@@ -271,16 +271,26 @@ struct FlattenMemrefsPass
 
 } // namespace
 
-void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
-  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
-                  MemRefRewritePattern<memref::StoreOp>,
-                  MemRefRewritePattern<memref::AllocOp>,
-                  MemRefRewritePattern<memref::AllocaOp>,
-                  MemRefRewritePattern<vector::LoadOp>,
+void memref::populateFlattenVectorOpsOnMemrefPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<MemRefRewritePattern<vector::LoadOp>,
                   MemRefRewritePattern<vector::StoreOp>,
                   MemRefRewritePattern<vector::TransferReadOp>,
                   MemRefRewritePattern<vector::TransferWriteOp>,
                   MemRefRewritePattern<vector::MaskedLoadOp>,
                   MemRefRewritePattern<vector::MaskedStoreOp>>(
       patterns.getContext());
 }
+
+void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
+  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
+                  MemRefRewritePattern<memref::StoreOp>,
+                  MemRefRewritePattern<memref::AllocOp>,
+                  MemRefRewritePattern<memref::AllocaOp>>(
+      patterns.getContext());
+}
+
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+  populateFlattenMemrefOpsPatterns(patterns);
+  populateFlattenVectorOpsOnMemrefPatterns(patterns);
+}
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 45 additions & 21 deletions
@@ -38,6 +38,8 @@
 #include <cstdint>
 #include <optional>
 
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+
 using namespace mlir;
 
 #define DEBUG_TYPE "vector-narrow-type-emulation"
@@ -556,7 +558,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
     if (op.getValueToStore().getType().getRank() != 1)
       return rewriter.notifyMatchFailure(op,
                                          "only 1-D vectors are supported ATM");
@@ -817,7 +818,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+/// Converts `vector.maskedstore` operations on narrow element types to work
+/// with wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
+/// (each `i8` container element holds two `i4` values) and storing with an
+/// adjusted mask.
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -826,10 +833,10 @@ struct ConvertVectorMaskedStore final
   matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
+    // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
     if (op.getValueToStore().getType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in vector.maskedstore op must be flattened beforehand.");
 
     auto loc = op.getLoc();
     auto containerElemTy =
@@ -931,18 +938,27 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+/// Converts `vector.load` on narrow element types to work with
+/// wider, byte-aligned container types by adjusting load sizes and using
+/// bitcasting.
+///
+/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
+/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
+/// container holds two `i4` values) and bitcasting back.
+///
+/// There are cases where the number of elements to load is not byte-aligned. In
+/// those cases, loads are converted to byte-aligned, byte-sized loads and the
+/// target vector is extracted from the loaded vector.
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    // See #115653
+    // Prerequisite: memref in the vector.load op is flattened into 1-D.
     if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in emulated vector ops must be flattened beforehand.");
 
     auto loc = op.getLoc();
     auto containerElemTy =
@@ -961,8 +977,6 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
-    // Here only the 1-D vector load is considered, and the N-D memref types
-    // should be linearized.
     // For example, to emulate i4 to i8, the following op:
     //
     //   %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
@@ -1037,18 +1051,22 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+/// Converts `vector.maskedload` operations on narrow element types to work with
+/// wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
+/// bitcasting, since each `i8` container element holds two `i4` values.
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // See #115653
     if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in emulated vector ops must be flattened beforehand.");
 
     auto loc = op.getLoc();
 
@@ -1229,7 +1247,6 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
 
   int elemsPerMultiByte = multiByteBits / subByteBits;
 
-  // TODO: This is a bit too restrictive for vectors rank > 1.
   return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
 }
 
@@ -1246,10 +1263,11 @@ struct ConvertVectorTransferRead final
   matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
+    // Prerequisite: memref in the vector.transfer_read op is flattened into
+    // 1-D.
     if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in emulated vector ops must be flattened beforehand.");
 
     auto loc = op.getLoc();
     auto containerElemTy =
@@ -2227,7 +2245,6 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns, bool disableAtomicRMW) {
-
   // Populate `vector.*` conversion patterns.
   // TODO: #119553 support atomicity
   patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
@@ -2266,3 +2283,10 @@ void vector::populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
 }
+
+void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
+  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+}
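
The arithmetic behind the examples in the new doc comments, written out as an illustrative snippet (not code from this commit):

    // With an i8 container and i4 elements, each container byte packs two values.
    constexpr int containerBits = 8, elemBits = 4;
    constexpr int scale = containerBits / elemBits; // = 2
    // vector<6xi4>   -> 6 / 2  = 3 -> vector<3xi8>  (maskedload/maskedstore docs)
    // memref<3x4xi4> -> 12 / 2 = 6 -> memref<6xi8>  (vector.load doc)
    static_assert(6 / scale == 3 && (3 * 4) / scale == 6, "doc-comment examples");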
Lines changed: 64 additions & 0 deletions (new test file)
@@ -0,0 +1,64 @@
+// RUN: mlir-opt --test-memref-flatten-and-vector-narrow-type-emulation --split-input-file %s | FileCheck %s
+
+// This test verifies that narrow-type-emulation works correctly for
+// rank > 1 memrefs by combining memref flattening with vector narrow type
+// emulation patterns.
+//
+// The patterns tested here demonstrate the composition of two transformations,
+// memref flattening for vector ops and vector op narrow type emulation.
+//
+// TODO: Support `vector.transfer_write` operation.
+
+func.func @vector_load_2d_i4(%arg0: index) -> vector<8xi4> {
+  %0 = memref.alloc() : memref<4x8xi4>
+  %1 = vector.load %0[%arg0, %arg0] : memref<4x8xi4>, vector<8xi4>
+  return %1 : vector<8xi4>
+}
+// CHECK-LABEL: func @vector_load_2d_i4
+// CHECK: vector.load {{.*}} memref<16xi8>
+
+// -----
+
+func.func @vector_maskedload_2d_i4(%arg0: index, %passthru: vector<8xi4>) -> vector<8xi4> {
+  %0 = memref.alloc() : memref<4x8xi4>
+  %mask = vector.constant_mask [6] : vector<8xi1>
+  %1 = vector.maskedload %0[%arg0, %arg0], %mask, %passthru :
+    memref<4x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+  return %1 : vector<8xi4>
+}
+// CHECK-LABEL: func @vector_maskedload_2d_i4(
+// CHECK: vector.maskedload {{.*}} memref<16xi8>
+
+// -----
+
+func.func @vector_maskedstore_2d_i4(%arg0: index, %value: vector<8xi4>) {
+  %0 = memref.alloc() : memref<4x8xi4>
+  %mask = vector.constant_mask [5] : vector<8xi1>
+  vector.maskedstore %0[%arg0, %arg0], %mask, %value :
+    memref<4x8xi4>, vector<8xi1>, vector<8xi4>
+  return
+}
+// CHECK-LABEL: func @vector_maskedstore_2d_i4(
+// CHECK: vector.maskedstore {{.*}} memref<16xi8>
+
+// -----
+
+func.func @vector_store_2d_i4(%arg0: index, %value: vector<8xi4>) {
+  %0 = memref.alloc() : memref<4x8xi4>
+  vector.store %value, %0[%arg0, %arg0] : memref<4x8xi4>, vector<8xi4>
+  return
+}
+// CHECK-LABEL: func @vector_store_2d_i4(
+// CHECK: vector.store {{.*}} memref<16xi8>
+
+// -----
+
+func.func @vector_transfer_read_2d_i4(%arg0: index, %padding: i4) -> vector<8xi4> {
+  %0 = memref.alloc() : memref<4x8xi4>
+  %1 = vector.transfer_read %0[%arg0, %arg0], %padding {in_bounds = [true]} : memref<4x8xi4>, vector<8xi4>
+  return %1 : vector<8xi4>
+}
+// CHECK-LABEL: func @vector_transfer_read_2d_i4(
+// CHECK-SAME: %{{.*}}: index, %[[PADDING_I4:.*]]: i4)
+// CHECK: %[[PADDING_I8:.*]] = arith.extui %[[PADDING_I4]] : i4 to i8
+// CHECK: vector.transfer_read {{.*}}, %[[PADDING_I8]] : memref<16xi8>, vector<4xi8>
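
The CHECK lines above follow the same packing arithmetic; an illustrative sanity check (not part of the test):

    // memref<4x8xi4> holds 4 * 8 = 32 i4 values = 32 * 4 / 8 = 16 bytes, so the
    // flattened, emulated type is memref<16xi8>; vector<8xi4> becomes vector<4xi8>.
    static_assert(4 * 8 * 4 / 8 == 16 && 8 * 4 / 8 == 4, "expected CHECK types");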

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

Lines changed: 64 additions & 0 deletions
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 
@@ -126,10 +127,73 @@ struct TestEmulateNarrowTypePass
                      "normal sequence"),
       llvm::cl::init(false)};
 };
+
+struct TestMemRefFlattenAndVectorNarrowTypeEmulationPass
+    : public PassWrapper<TestMemRefFlattenAndVectorNarrowTypeEmulationPass,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestMemRefFlattenAndVectorNarrowTypeEmulationPass)
+
+  TestMemRefFlattenAndVectorNarrowTypeEmulationPass() = default;
+  TestMemRefFlattenAndVectorNarrowTypeEmulationPass(
+      const TestMemRefFlattenAndVectorNarrowTypeEmulationPass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
+                vector::VectorDialect, affine::AffineDialect>();
+  }
+
+  StringRef getArgument() const final {
+    return "test-memref-flatten-and-vector-narrow-type-emulation";
+  }
+
+  StringRef getDescription() const final {
+    return "Test MemRef flattening and vector narrow type emulation patterns";
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = &getContext();
+
+    // Create a type converter for narrow type emulation (8-bit)
+    arith::NarrowTypeEmulationConverter typeConverter(8);
+
+    // Add conversions for memref types with i4 elements
+    memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+
+    ConversionTarget target(*ctx);
+    target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+      return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+    });
+    auto opLegalCallback = [&typeConverter](Operation *op) {
+      return typeConverter.isLegal(op);
+    };
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
+    target.addDynamicallyLegalDialect<
+        arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
+        affine::AffineDialect>(opLegalCallback);
+
+    RewritePatternSet patterns(ctx);
+
+    // This is necessary for the purpose of emulating `memref.alloc` and
+    // function boundaries.
+    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+
+    vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
+        typeConverter, patterns);
+
+    // Apply partial conversion
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
 } // namespace
 
 namespace mlir::test {
 void registerTestEmulateNarrowTypePass() {
   PassRegistration<TestEmulateNarrowTypePass>();
+  PassRegistration<TestMemRefFlattenAndVectorNarrowTypeEmulationPass>();
 }
 } // namespace mlir::test
