@@ -54,32 +54,27 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
5454 return slicedIndices;
5555}
5656
57- // Compute the new indices for vector.load/store by adding `offsets` to
58- // `originalIndices`.
59- // It assumes m <= n (m = offsets.size(), n = originalIndices.size())
60- // Last m of `originalIndices` will be updated.
57+ // Compute the new indices by adding `offsets` to `originalIndices`.
58+ // If m < n (m = offsets.size(), n = originalIndices.size()),
59+ // then only the trailing m indices are offset (in the returned copy).
6160static SmallVector<Value> computeIndices (PatternRewriter &rewriter,
6261 Location loc,
6362 ArrayRef<Value> originalIndices,
6463 ArrayRef<int64_t > offsets) {
6564 assert (offsets.size () <= originalIndices.size () &&
6665 " Offsets should not exceed the number of original indices" );
6766 SmallVector<Value> indices (originalIndices);
68- auto originalIter = originalIndices.rbegin ();
69- auto offsetsIter = offsets.rbegin ();
70- auto indicesIter = indices.rbegin ();
71- while (offsetsIter != offsets.rend ()) {
72- Value original = *originalIter;
73- int64_t offset = *offsetsIter;
74- if (offset != 0 )
75- *indicesIter = rewriter.create <arith::AddIOp>(
76- loc, original, rewriter.create <arith::ConstantIndexOp>(loc, offset));
77- originalIter++;
78- offsetsIter++;
79- indicesIter++;
67+
68+ auto start = indices.size () - offsets.size ();
69+ for (auto [i, offset] : llvm::enumerate (offsets)) {
70+ if (offset != 0 ) {
71+ indices[start + i] = rewriter.create <arith::AddIOp>(
72+ loc, originalIndices[start + i],
73+ rewriter.create <arith::ConstantIndexOp>(loc, offset));
74+ }
8075 }
8176 return indices;
82- };
77+ }
8378
8479// Clones `op` into a new operation that takes `operands` and returns
8580// `resultTypes`.
@@ -658,7 +653,6 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
658653 vector::UnrollVectorOptions options;
659654};
660655
661- // clang-format off
662656// This pattern unrolls the vector load into multiple 1D vector loads by
663657// extracting slices from the base memory and inserting them into the result
664658// vector using vector.insert_strided_slice.
@@ -667,11 +661,13 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
667661// is converted to :
668662// %cst = arith.constant dense<0.0> : vector<4x4xf32>
669663// %slice_0 = vector.load %base[%indices] : memref<4x4xf32>, vector<4xf32>
670- // %result_0 = vector.insert_strided_slice %slice_0, %cst {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
671- // %slice_1 = vector.load %base[%indices + 1] : memref<4x4xf32>, vector<4xf32>
672- // %result_1 = vector.insert_strided_slice %slice_1, %result_0 {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
664+ // %result_0 = vector.insert_strided_slice %slice_0, %cst
665+ // {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
666+ // %slice_1 = vector.load %base[%indices + 1]
667+ // : memref<4x4xf32>, vector<4xf32>
668+ // %result_1 = vector.insert_strided_slice %slice_1, %result_0
669+ // {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
673670// ...
674- // clang-format on
675671struct UnrollLoadPattern : public OpRewritePattern <vector::LoadOp> {
676672 UnrollLoadPattern (MLIRContext *context,
677673 const vector::UnrollVectorOptions &options,
0 commit comments