Skip to content

Commit 9fbf8c0

Browse files
committed
Address comments
1 parent a0f94dd commit 9fbf8c0

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,21 +83,21 @@ static Value createTileFromElements(PatternRewriter &rewriter, Location loc,
8383
ArrayRef<int64_t> tileOffsets,
8484
ArrayRef<int64_t> tileShape,
8585
VectorType tileType) {
86-
8786
// Initialize tile with zeros.
8887
Value tile = rewriter.create<arith::ConstantOp>(
8988
loc, tileType, rewriter.getZeroAttr(tileType));
9089

91-
// Calculate strides for both source and result shapes.
90+
// Calculate strides for source, result, and tile shapes.
9291
SmallVector<int64_t> sourceStrides = computeStrides(sourceShape);
9392
SmallVector<int64_t> resultStrides = computeStrides(resultShape);
93+
SmallVector<int64_t> tileStrides = computeStrides(tileShape);
94+
int64_t numElementsInTile = computeProduct(tileShape);
9495

9596
// Iterate over all positions in the tile using linear indexing.
96-
for (int64_t linearTileIdx = 0; linearTileIdx < computeProduct(tileShape);
97+
for (int64_t linearTileIdx = 0; linearTileIdx < numElementsInTile;
9798
++linearTileIdx) {
9899
// Convert linear tile index to multi-dimensional tile position.
99-
SmallVector<int64_t> tilePosition =
100-
delinearize(linearTileIdx, computeStrides(tileShape));
100+
SmallVector<int64_t> tilePosition = delinearize(linearTileIdx, tileStrides);
101101

102102
// Calculate the global position in the result.
103103
SmallVector<int64_t> globalResultPos;
@@ -108,18 +108,13 @@ static Value createTileFromElements(PatternRewriter &rewriter, Location loc,
108108

109109
// Convert result position to linear index.
110110
int64_t linearIndex = linearize(globalResultPos, resultStrides);
111-
112111
// Convert linear index to source position.
113-
SmallVector<int64_t> sourcePos =
114-
delinearize(linearIndex, computeStrides(sourceShape));
115-
112+
SmallVector<int64_t> sourcePos = delinearize(linearIndex, sourceStrides);
116113
// Extract element from source.
117114
Value element = vector::ExtractOp::create(rewriter, loc, source, sourcePos);
118-
119115
// Insert element into tile.
120116
tile = vector::InsertOp::create(rewriter, loc, element, tile, tilePosition);
121117
}
122-
123118
return tile;
124119
}
125120

0 commit comments

Comments
 (0)