Skip to content

Commit a9d7260

Browse files
committed
Refactor and fixes
1 parent 0b9bdce commit a9d7260

File tree

3 files changed

+84
-145
lines changed

3 files changed

+84
-145
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,17 @@ using namespace mlir;
3838

3939
/// Returns a compressed mask. The mask value is set only if any mask is present
4040
/// in the scale range. E.g., if `scale` equals 2, and `intraDataOffset`
41-
/// equals to 2, the following mask:
41+
/// equals 1 (intraDataOffset strictly smaller than scale), the following
42+
/// mask:
4243
///
43-
/// %mask = [1, 1, 1, 0, 0, 0]
44+
/// %mask = [1, 1, 0, 0, 0, 0]
4445
///
4546
/// will first be padded at the front with `intraDataOffset` zeros:
46-
/// %mask = [0, 0, 1, 1, 1, 0, 0, 0]
47+
/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
4748
///
4849
/// then it will return the following new compressed mask:
4950
///
50-
/// %mask = [0, 1, 1, 0]
51+
/// %mask = [1, 1, 0, 0]
5152
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
5253
Location loc, Value mask,
5354
int origElements, int scale,
@@ -76,9 +77,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
7677
shape.back() = numElements;
7778
auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
7879
if (createMaskOp) {
79-
// TODO: handle the case with non-zero intraDataOffset for CreateMaskOp.
80-
if (intraDataOffset != 0)
81-
return failure();
8280
OperandRange maskOperands = createMaskOp.getOperands();
8381
size_t numMaskOperands = maskOperands.size();
8482
AffineExpr s0;
@@ -130,10 +128,18 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
130128
return newMask;
131129
}
132130

133-
/// A wrapper function for emitting `vector.extract_strided_slice`.
131+
/// A wrapper function for emitting `vector.extract_strided_slice`. The vector
132+
/// has to be of 1-D shape.
134133
static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
135134
VectorType extractType, Value vector,
136135
int64_t frontOffset, int64_t subvecSize) {
136+
// Get the type of `vector`; both it and `extractType` must be 1-D vectors.
137+
auto vectorType = dyn_cast<VectorType>(vector.getType());
138+
assert(vectorType && "expected vector type");
139+
assert(vectorType.getShape().size() == 1 && "expected 1-D vector type");
140+
assert(extractType.getShape().size() == 1 &&
141+
"extractType must be 1-D vector type");
142+
137143
auto offsets = rewriter.getI64ArrayAttr({frontOffset});
138144
auto sizes = rewriter.getI64ArrayAttr({subvecSize});
139145
auto strides = rewriter.getI64ArrayAttr({1});
@@ -143,9 +149,17 @@ static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
143149
->getResult(0);
144150
}
145151

146-
/// A wrapper function for emitting `vector.insert_strided_slice`.
152+
/// A wrapper function for emitting `vector.insert_strided_slice`. The source
153+
/// and dest vectors must be of 1-D shape.
147154
static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
148155
Value src, Value dest, int64_t offset) {
156+
auto srcType = dyn_cast<VectorType>(src.getType());
157+
assert(srcType && "expected vector type");
158+
assert(srcType.getShape().size() == 1 && "expected 1-D vector type");
159+
auto destType = dyn_cast<VectorType>(dest.getType());
160+
assert(destType && "expected vector type");
161+
assert(destType.getShape().size() == 1 && "expected 1-D vector type");
162+
149163
auto offsets = rewriter.getI64ArrayAttr({offset});
150164
auto strides = rewriter.getI64ArrayAttr({1});
151165
return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -157,24 +171,20 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
157171
/// `srcOffsetVar` is not a constant, making it impossible to use
158172
/// vector.extract_strided_slice, as it requires constant offsets.
159173
static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
160-
TypedValue<VectorType> srcVec,
161-
Value destVec,
162-
OpFoldResult srcOffsetVar,
163-
int64_t lengthSubvec) {
164-
for (int i = 0; i < lengthSubvec; ++i) {
165-
Value extractLoc;
166-
if (i == 0) {
167-
extractLoc = srcOffsetVar.dyn_cast<Value>();
168-
} else {
169-
extractLoc = rewriter.create<arith::AddIOp>(
170-
loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(),
171-
rewriter.create<arith::ConstantIndexOp>(loc, i));
172-
}
174+
TypedValue<VectorType> source,
175+
Value dest, OpFoldResult offset,
176+
int64_t numElementsToExtract) {
177+
for (int i = 0; i < numElementsToExtract; ++i) {
178+
Value extractLoc =
179+
(i == 0) ? offset.dyn_cast<Value>()
180+
: rewriter.create<arith::AddIOp>(
181+
loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
182+
rewriter.create<arith::ConstantIndexOp>(loc, i));
173183
auto extractOp =
174-
rewriter.create<vector::ExtractOp>(loc, srcVec, extractLoc);
175-
destVec = rewriter.create<vector::InsertOp>(loc, extractOp, destVec, i);
184+
rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
185+
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
176186
}
177-
return destVec;
187+
return dest;
178188
}
179189

180190
/// Load `numElementsToLoad` of `newElementType` from `base` at
@@ -183,15 +193,15 @@ static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
183193
static TypedValue<VectorType>
184194
emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc,
185195
Value base, OpFoldResult linearizedIndices,
186-
int64_t numLoadedElements, Type oldElememtType,
196+
int64_t numElementsToLoad, Type oldElememtType,
187197
Type newElementType) {
188198
auto scale = newElementType.getIntOrFloatBitWidth() /
189199
oldElememtType.getIntOrFloatBitWidth();
190200
auto newLoad = rewriter.create<vector::LoadOp>(
191-
loc, VectorType::get(numLoadedElements, newElementType), base,
201+
loc, VectorType::get(numElementsToLoad, newElementType), base,
192202
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
193203
return rewriter.create<vector::BitCastOp>(
194-
loc, VectorType::get(numLoadedElements * scale, oldElememtType), newLoad);
204+
loc, VectorType::get(numElementsToLoad * scale, oldElememtType), newLoad);
195205
};
196206

197207
namespace {

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir

Lines changed: 0 additions & 52 deletions
This file was deleted.
Lines changed: 46 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
22

3+
// CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
4+
// CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
5+
36
func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
47
%0 = memref.alloc() : memref<3x3xi2>
58
%c0 = arith.constant 0 : index
@@ -19,25 +22,6 @@ func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
1922

2023
//-----
2124

22-
func.func @vector_load_i2_unaligned(%arg1: index, %arg2: index) -> vector<3x3xi2> {
23-
%0 = memref.alloc() : memref<3x3xi2>
24-
%c0 = arith.constant 0 : index
25-
%c1 = arith.constant 1 : index
26-
%cst = arith.constant dense<0> : vector<3x3xi2>
27-
%1 = vector.load %0[%c0, %c1] : memref<3x3xi2>, vector<3xi2>
28-
%2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
29-
return %2 : vector<3x3xi2>
30-
}
31-
32-
// CHECK: func @vector_load_i2_unaligned
33-
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
34-
// CHECK: %[[INDEX:.+]] = arith.constant 0 : index
35-
// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<1xi8>
36-
// CHECK: %[[VEC_I2:.+]] = vector.bitcast %[[VEC]] : vector<1xi8> to vector<4xi2>
37-
// CHECK: %[[EXCTRACT:.+]] = vector.extract_strided_slice %[[VEC_I2]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
38-
39-
//-----
40-
4125
func.func @vector_transfer_read_i2() -> vector<3xi2> {
4226
%0 = memref.alloc() : memref<3x3xi2>
4327
%c0i2 = arith.constant 0 : i2
@@ -56,26 +40,6 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {
5640

5741
//-----
5842

59-
func.func @vector_transfer_read_i2_unaligned() -> vector<3xi2> {
60-
%0 = memref.alloc() : memref<3x3xi2>
61-
%c0i2 = arith.constant 0 : i2
62-
%c0 = arith.constant 0 : index
63-
%c1 = arith.constant 1 : index
64-
%1 = vector.transfer_read %0[%c0, %c1], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
65-
return %1 : vector<3xi2>
66-
}
67-
68-
// CHECK: func @vector_transfer_read_i2_unaligned
69-
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
70-
// CHECK: %[[PAD:.+]] = arith.constant 0 : i2
71-
// CHECK: %[[EXT:.+]] = arith.extui %[[PAD]] : i2 to i8
72-
// CHECK: %[[INDEX:.+]] = arith.constant 0 : index
73-
// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[EXT]] : memref<3xi8>, vector<1xi8>
74-
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<1xi8> to vector<4xi2>
75-
// CHECK: vector.extract_strided_slice %[[BITCAST]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
76-
77-
//-----
78-
7943
func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
8044
%0 = memref.alloc() : memref<3x5xi2>
8145
%cst = arith.constant dense<0> : vector<3x5xi2>
@@ -107,32 +71,49 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
10771

10872
//-----
10973

110-
func.func @vector_cst_maskedload_i2_unaligned(%passthru: vector<5xi2>) -> vector<3x5xi2> {
111-
%0 = memref.alloc() : memref<3x5xi2>
112-
%cst = arith.constant dense<0> : vector<3x5xi2>
113-
%mask = vector.constant_mask [3] : vector<5xi1>
114-
%c0 = arith.constant 0 : index
115-
%c1 = arith.constant 1 : index
116-
%1 = vector.maskedload %0[%c0, %c1], %mask, %passthru :
117-
memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
118-
%2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
119-
return %2 : vector<3x5xi2>
74+
func.func @vector_load_i2_dynamic_indexing(%arg1: index, %arg2: index) -> vector<3xi2> {
75+
%0 = memref.alloc() : memref<3x3xi2>
76+
%cst = arith.constant dense<0> : vector<3x3xi2>
77+
%1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
78+
return %1 : vector<3xi2>
12079
}
12180

81+
// CHECK: func @vector_load_i2_dynamic_indexing
82+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
83+
// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
84+
// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
85+
// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %[[ALLOC]][%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
86+
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
87+
// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
88+
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
89+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
90+
// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
91+
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
92+
// CHECK: %[[C2:.+]] = arith.constant 2 : index
93+
// CHECK: %[[OFFSET2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
94+
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
12295

123-
// CHECK: func @vector_cst_maskedload_i2_unaligned
124-
// CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
125-
// CHECK: %[[NEWMASK:.+]] = arith.constant dense<[true, false]> : vector<2xi1>
126-
// CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
127-
// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %arg0, %[[VESSEL]]
128-
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
129-
// CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<8xi2> to vector<2xi8>
130-
// CHECK: %[[C0:.+]] = arith.constant 0 : index
131-
// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %alloc[%[[C0]]], %[[NEWMASK:.+]], %[[BITCAST1]]
132-
// CHECK-SAME: : memref<4xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
133-
// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
134-
// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
135-
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[ORIGINMASK]], %[[CST2]]
136-
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
137-
// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BITCAST2]], %[[INSERT1]] : vector<8xi1>, vector<8xi2>
138-
// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
96+
//-----
97+
98+
func.func @vector_transfer_read_i2_dynamic_indexing(%arg1: index, %arg2: index) -> vector<3xi2> {
99+
%0 = memref.alloc() : memref<3x3xi2>
100+
%pad = arith.constant 0 : i2
101+
%1 = vector.transfer_read %0[%arg1, %arg2], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
102+
return %1 : vector<3xi2>
103+
}
104+
105+
// CHECK: func @vector_transfer_read_i2_dynamic_indexing
106+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
107+
// CHECK: %[[C0:.+]] = arith.extui %{{.+}} : i2 to i8
108+
// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
109+
// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
110+
// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
111+
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
112+
// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
113+
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
114+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
115+
// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
116+
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
117+
// CHECK: %[[C2:.+]] = arith.constant 2 : index
118+
// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
119+
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>

0 commit comments

Comments
 (0)