@@ -83,21 +83,21 @@ static Value createTileFromElements(PatternRewriter &rewriter, Location loc,
8383 ArrayRef<int64_t > tileOffsets,
8484 ArrayRef<int64_t > tileShape,
8585 VectorType tileType) {
86-
8786 // Initialize tile with zeros.
8887 Value tile = rewriter.create <arith::ConstantOp>(
8988 loc, tileType, rewriter.getZeroAttr (tileType));
9089
91- // Calculate strides for both source and result shapes.
90+ // Calculate strides for source, result, and tile shapes.
9291 SmallVector<int64_t > sourceStrides = computeStrides (sourceShape);
9392 SmallVector<int64_t > resultStrides = computeStrides (resultShape);
93+ SmallVector<int64_t > tileStrides = computeStrides (tileShape);
94+ int64_t numElementsInTile = computeProduct (tileShape);
9495
9596 // Iterate over all positions in the tile using linear indexing.
96- for (int64_t linearTileIdx = 0 ; linearTileIdx < computeProduct (tileShape) ;
97+ for (int64_t linearTileIdx = 0 ; linearTileIdx < numElementsInTile ;
9798 ++linearTileIdx) {
9899 // Convert linear tile index to multi-dimensional tile position.
99- SmallVector<int64_t > tilePosition =
100- delinearize (linearTileIdx, computeStrides (tileShape));
100+ SmallVector<int64_t > tilePosition = delinearize (linearTileIdx, tileStrides);
101101
102102 // Calculate the global position in the result.
103103 SmallVector<int64_t > globalResultPos;
@@ -108,18 +108,13 @@ static Value createTileFromElements(PatternRewriter &rewriter, Location loc,
108108
109109 // Convert result position to linear index.
110110 int64_t linearIndex = linearize (globalResultPos, resultStrides);
111-
112111 // Convert linear index to source position.
113- SmallVector<int64_t > sourcePos =
114- delinearize (linearIndex, computeStrides (sourceShape));
115-
112+ SmallVector<int64_t > sourcePos = delinearize (linearIndex, sourceStrides);
116113 // Extract element from source.
117114 Value element = vector::ExtractOp::create (rewriter, loc, source, sourcePos);
118-
119115 // Insert element into tile.
120116 tile = vector::InsertOp::create (rewriter, loc, element, tile, tilePosition);
121117 }
122-
123118 return tile;
124119}
125120
0 commit comments