Skip to content

Commit 5a2070b

Browse files
committed
Simplify computeIndices
1 parent 3f40948 commit 5a2070b

File tree

1 file changed

+18
-22
lines changed

1 file changed

+18
-22
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -54,32 +54,27 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
5454
return slicedIndices;
5555
}
5656

57-
// Compute the new indices for vector.load/store by adding `offsets` to
58-
// `originalIndices`.
59-
// It assumes m <= n (m = offsets.size(), n = originalIndices.size())
60-
// Last m of `originalIndices` will be updated.
57+
// Compute the new indices by adding `offsets` to `originalIndices`.
58+
// If m < n (m = offsets.size(), n = originalIndices.size()),
59+
// then only the trailing m values in `originalIndices` are updated.
6160
static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
6261
Location loc,
6362
ArrayRef<Value> originalIndices,
6463
ArrayRef<int64_t> offsets) {
6564
assert(offsets.size() <= originalIndices.size() &&
6665
"Offsets should not exceed the number of original indices");
6766
SmallVector<Value> indices(originalIndices);
68-
auto originalIter = originalIndices.rbegin();
69-
auto offsetsIter = offsets.rbegin();
70-
auto indicesIter = indices.rbegin();
71-
while (offsetsIter != offsets.rend()) {
72-
Value original = *originalIter;
73-
int64_t offset = *offsetsIter;
74-
if (offset != 0)
75-
*indicesIter = rewriter.create<arith::AddIOp>(
76-
loc, original, rewriter.create<arith::ConstantIndexOp>(loc, offset));
77-
originalIter++;
78-
offsetsIter++;
79-
indicesIter++;
67+
68+
auto start = indices.size() - offsets.size();
69+
for (auto [i, offset] : llvm::enumerate(offsets)) {
70+
if (offset != 0) {
71+
indices[start + i] = rewriter.create<arith::AddIOp>(
72+
loc, originalIndices[start + i],
73+
rewriter.create<arith::ConstantIndexOp>(loc, offset));
74+
}
8075
}
8176
return indices;
82-
};
77+
}
8378

8479
// Clones `op` into a new operation that takes `operands` and returns
8580
// `resultTypes`.
@@ -658,7 +653,6 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
658653
vector::UnrollVectorOptions options;
659654
};
660655

661-
// clang-format off
662656
// This pattern unrolls the vector load into multiple 1D vector loads by
663657
// extracting slices from the base memory and inserting them into the result
664658
// vector using vector.insert_strided_slice.
@@ -667,11 +661,13 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
667661
// is converted to :
668662
// %cst = arith.constant dense<0.0> : vector<4x4xf32>
669663
// %slice_0 = vector.load %base[%indices] : memref<4x4xf32>, vector<4xf32>
670-
// %result_0 = vector.insert_strided_slice %slice_0, %cst {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
671-
// %slice_1 = vector.load %base[%indices + 1] : memref<4x4xf32>, vector<4xf32>
672-
// %result_1 = vector.insert_strided_slice %slice_1, %result_0 {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
664+
// %result_0 = vector.insert_strided_slice %slice_0, %cst
665+
// {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
666+
// %slice_1 = vector.load %base[%indices + 1]
667+
// : memref<4x4xf32>, vector<4xf32>
668+
// %result_1 = vector.insert_strided_slice %slice_1, %result_0
669+
// {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
673670
// ...
674-
// clang-format on
675671
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
676672
UnrollLoadPattern(MLIRContext *context,
677673
const vector::UnrollVectorOptions &options,

0 commit comments

Comments
 (0)