Commit 2580b46 ("Update tests")
Parent: 2effa6a

3 files changed (+93, -37 lines)

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 19 additions & 33 deletions
@@ -130,6 +130,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   return newMask;
 }
 
+/// A wrapper function for emitting `vector.extract_strided_slice`.
 static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
                                   VectorType extractType, Value vector,
                                   int64_t frontOffset, int64_t subvecSize) {
@@ -142,6 +143,7 @@ static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
       ->getResult(0);
 }
 
+/// A wrapper function for emitting `vector.insert_strided_slice`.
 static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
                                  Value src, Value dest, int64_t offset) {
   auto offsets = rewriter.getI64ArrayAttr({offset});
@@ -150,36 +152,14 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
                                 dest, offsets, strides);
 }
 
+/// Extracts `lengthSubvec` elements from `srcVec` into `destVec` starting at
+/// the offset specified by `srcOffsetVar`. Use this function when
+/// `srcOffsetVar` is not a constant, making it impossible to use
+/// vector.extract_strided_slice, as it requires constant offsets.
 static void dynamicallyExtractElementsToVector(
     RewriterBase &rewriter, Location loc, TypedValue<VectorType> srcVec,
-    Value destVec, OpFoldResult srcOffsetVar, int64_t loopSize) {
-  /*
-  // Create affine maps for the lower and upper bounds
-  AffineMap lowerBoundMap = AffineMap::getConstantMap(0, rewriter.getContext());
-  AffineMap upperBoundMap =
-      AffineMap::getConstantMap(loopSize, rewriter.getContext());
-
-  auto forLoop = rewriter.create<affine::AffineForOp>(
-      loc, ValueRange{}, lowerBoundMap, ValueRange{}, upperBoundMap, 1,
-      ArrayRef<Value>(destVec));
-
-  OpBuilder builder =
-      OpBuilder::atBlockEnd(forLoop.getBody(), rewriter.getListener());
-
-  auto iv = forLoop.getInductionVar();
-
-  auto loopDestVec = forLoop.getRegionIterArgs()[0];
-  auto extractLoc = builder.create<arith::AddIOp>(
-      loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(), iv);
-  auto extractElemOp = builder.create<vector::ExtractElementOp>(
-      loc, elemType, srcVec, extractLoc);
-  auto insertElemOp = builder.create<vector::InsertElementOp>(
-      loc, extractElemOp, loopDestVec, iv);
-  builder.create<affine::AffineYieldOp>(loc,
-                                        ValueRange{insertElemOp->getResult(0)});
-  return forLoop->getResult(0);
-  */
-  for (int i = 0; i < loopSize; ++i) {
+    Value destVec, OpFoldResult srcOffsetVar, int64_t lengthSubvec) {
+  for (int i = 0; i < lengthSubvec; ++i) {
     Value extractLoc;
     if (i == 0) {
       extractLoc = srcOffsetVar.dyn_cast<Value>();
@@ -194,15 +174,21 @@ static void dynamicallyExtractElementsToVector(
   }
 }
 
+/// Load `numLoadedElements` of `newElementType` from `base` at
+/// `linearizedIndices`, then bitcast the result into a vector of
+/// `oldElementType`.
 static TypedValue<VectorType>
 emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc,
-                   Value base, OpFoldResult linearizedIndices, int64_t numBytes,
-                   int64_t scale, Type oldElememtType, Type newElementType) {
+                   Value base, OpFoldResult linearizedIndices,
+                   int64_t numLoadedElements, Type oldElememtType,
+                   Type newElementType) {
+  auto scale = newElementType.getIntOrFloatBitWidth() /
+               oldElememtType.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
-      loc, VectorType::get(numBytes, newElementType), base,
+      loc, VectorType::get(numLoadedElements, newElementType), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numBytes * scale, oldElememtType), newLoad);
+      loc, VectorType::get(numLoadedElements * scale, oldElememtType), newLoad);
 };
 
 namespace {
@@ -443,7 +429,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
         llvm::divideCeil(maxintraDataOffset + origElements, scale);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
-                           numElements, scale, oldElementType, newElementType);
+                           numElements, oldElementType, newElementType);
 
     if (foldedIntraVectorOffset) {
       if (isUnalignedEmulation) {
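The main functional change above is that emulatedVectorLoad no longer takes a precomputed scale; it derives it from the two element types and the caller now passes only the number of container elements to load. A minimal standalone C++ sketch of that arithmetic for the i2-in-i8 case (the computeScale helper below is hypothetical and not part of the patch):

#include <cassert>
#include <cstdint>
#include <iostream>

// Hypothetical helper mirroring the in-patch computation
//   scale = newElementType bit width / oldElememtType bit width,
// e.g. an i8 container holds 4 emulated i2 elements.
int64_t computeScale(int64_t containerBits, int64_t emulatedBits) {
  assert(containerBits % emulatedBits == 0 && "widths must divide evenly");
  return containerBits / emulatedBits;
}

int main() {
  int64_t scale = computeScale(/*i8=*/8, /*i2=*/2); // 4
  int64_t numLoadedElements = 1;                    // one i8 loaded
  // The bitcast result then has numLoadedElements * scale emulated elements:
  // vector<1xi8> -> vector<4xi2>.
  std::cout << numLoadedElements * scale << " emulated i2 elements\n";
  return 0;
}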

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir

Lines changed: 2 additions & 3 deletions
@@ -2,14 +2,13 @@
 
 // CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
 // CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
-func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
   %cst = arith.constant dense<0> : vector<3x3xi2>
   %1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
-  %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
-  return %2 : vector<3x3xi2>
+  return %1 : vector<3xi2>
 }
 
 // CHECK: func @vector_load_i2
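The affine maps in the CHECK prefix express the index arithmetic this dynamic test exercises: for memref<3x3xi2>, the linearized element index is s0 * 3 + s1, the byte index is its floordiv by 4, and the intra-byte offset is its mod 4. Because that offset is only known at runtime, the emulation copies the result element by element (the role of dynamicallyExtractElementsToVector in the C++ change above) instead of using vector.extract_strided_slice. A standalone sketch of that arithmetic, with hypothetical runtime index values:

#include <array>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t elemsPerByte = 4;  // four i2 elements packed into one i8
  int64_t s0 = 1, s1 = 2;          // hypothetical runtime indices into memref<3x3xi2>
  int64_t linear = s0 * 3 + s1;    // row-major linearization, as in #map/#map1
  int64_t byteIndex = linear / elemsPerByte;    // (s0 * 3 + s1) floordiv 4
  int64_t intraOffset = linear % elemsPerByte;  // (s0 * 3 + s1) mod 4

  // Stand-in for the per-element extract/insert loop used when the offset
  // is not a compile-time constant.
  std::array<int, 12> unpacked{};  // 3x3 = 9 i2 values, padded to 3 full bytes
  std::array<int, 3> result{};
  for (int i = 0; i < 3; ++i)
    result[i] = unpacked[byteIndex * elemsPerByte + intraOffset + i];

  std::cout << "byte " << byteIndex << ", intra-byte offset " << intraOffset << "\n";
  return 0;
}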

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 72 additions & 1 deletion
@@ -19,6 +19,25 @@ func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
 
 //-----
 
+func.func @vector_load_i2_unaligned(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant dense<0> : vector<3x3xi2>
+  %1 = vector.load %0[%c0, %c1] : memref<3x3xi2>, vector<3xi2>
+  %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+  return %2 : vector<3x3xi2>
+}
+
+// CHECK: func @vector_load_i2_unaligned
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = arith.constant 0 : index
+// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<1xi8>
+// CHECK: %[[VEC_I2:.+]] = vector.bitcast %[[VEC]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[EXCTRACT:.+]] = vector.extract_strided_slice %[[VEC_I2]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
+
+//-----
+
 func.func @vector_transfer_read_i2() -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %c0i2 = arith.constant 0 : i2
@@ -37,6 +56,26 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {
 
 //-----
 
+func.func @vector_transfer_read_i2_unaligned() -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0i2 = arith.constant 0 : i2
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = vector.transfer_read %0[%c0, %c1], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK: func @vector_transfer_read_i2_unaligned
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[PAD:.+]] = arith.constant 0 : i2
+// CHECK: %[[EXT:.+]] = arith.extui %[[PAD]] : i2 to i8
+// CHECK: %[[INDEX:.+]] = arith.constant 0 : index
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[EXT]] : memref<3xi8>, vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<1xi8> to vector<4xi2>
+// CHECK: vector.extract_strided_slice %[[BITCAST]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
+
+//-----
+
 func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
   %0 = memref.alloc() : memref<3x5xi2>
   %cst = arith.constant dense<0> : vector<3x5xi2>
@@ -64,4 +103,36 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 // CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[ORIGINMASK]], %[[CST2]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi1> into vector<8xi1>
 // CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BITCAST2]], %[[INSERT1]] : vector<8xi1>, vector<8xi2>
-// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [2], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
+// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [2], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
+
+//-----
+
+func.func @vector_cst_maskedload_i2_unaligned(%passthru: vector<5xi2>) -> vector<3x5xi2> {
+  %0 = memref.alloc() : memref<3x5xi2>
+  %cst = arith.constant dense<0> : vector<3x5xi2>
+  %mask = vector.constant_mask [3] : vector<5xi1>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = vector.maskedload %0[%c0, %c1], %mask, %passthru :
+    memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
+  return %2 : vector<3x5xi2>
+}
+
+
+// CHECK: func @vector_cst_maskedload_i2_unaligned
+// CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
+// CHECK: %[[NEWMASK:.+]] = arith.constant dense<[true, false]> : vector<2xi1>
+// CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %arg0, %[[VESSEL]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
+// CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %alloc[%[[C0]]], %[[NEWMASK:.+]], %[[BITCAST1]]
+// CHECK-SAME: : memref<4xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[ORIGINMASK]], %[[CST2]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BITCAST2]], %[[INSERT1]] : vector<8xi1>, vector<8xi2>
+// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
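The constant-offset unaligned cases added above all follow the same pattern: the intra-byte offset and the number of container bytes to touch are known statically, so one load (or masked load), a bitcast, and a single vector.extract_strided_slice suffice. A standalone sketch of the arithmetic behind the vector_load_i2_unaligned case, with illustrative names only:

#include <cstdint>
#include <iostream>

// Ceiling division, standing in for the llvm::divideCeil call in the pattern.
int64_t divideCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t scale = 4;    // i2 elements per i8 byte
  int64_t linearIndex = 1;    // load starts at element [0, 1] of memref<3x3xi2>
  int64_t origElements = 3;   // result type is vector<3xi2>

  int64_t intraOffset = linearIndex % scale;                               // 1
  int64_t numLoadedElements = divideCeil(intraOffset + origElements, scale); // 1 i8

  // Matches the CHECK lines: load vector<1xi8>, bitcast to vector<4xi2>,
  // then extract_strided_slice {offsets = [1], sizes = [3]} -> vector<3xi2>.
  std::cout << numLoadedElements << " byte(s) loaded, extract at offset "
            << intraOffset << "\n";
  return 0;
}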
