@@ -162,60 +162,20 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
162162 stridedMetadata.getConstifiedMixedStrides ();
163163 SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes ();
164164 OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset ();
165+ memref::LinearizedMemRefInfo linearizedInfo;
165166 OpFoldResult linearizedIndices;
166- std::tie (std::ignore , linearizedIndices) =
167+ std::tie (linearizedInfo , linearizedIndices) =
167168 memref::getLinearizedMemRefOffsetAndSize (rewriter, loc, elementBitWidth,
168169 elementBitWidth, offset, sizes,
169170 strides, indices);
170171
171- // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
172- // Note below doesn't give the correct result for the linearized size.
173- // Value totalSize = getValueOrCreateConstantIndexOp(
174- // rewriter, loc, linearizedInfo.linearizedSize);
175- // It computes the multiplied sizes of all dimensions instead of taking
176- // the maximum of each dimension size * stride.
177- SmallVector<AffineExpr> productExpressions;
178- unsigned sourceRank = cast<ShapedType>(src.getType ()).getRank ();
179-
180- SmallVector<AffineExpr> symbols (2 * sourceRank);
181- SmallVector<Value> offsetValues;
182- bindSymbolsList (rewriter.getContext (), MutableArrayRef{symbols});
183-
184- size_t symbolIndex = 0 ;
185- for (size_t i = 0 ; i < sourceRank; ++i) {
186- AffineExpr strideExpr, sizeExpr;
187- OpFoldResult stride = strides[i];
188- OpFoldResult size = sizes[i];
189- if (auto constantStride = getConstantIntValue (stride)) {
190- strideExpr = rewriter.getAffineConstantExpr (*constantStride);
191- } else {
192- strideExpr = symbols[symbolIndex++];
193- offsetValues.push_back (
194- getValueOrCreateConstantIndexOp (rewriter, loc, stride));
195- }
196-
197- if (auto constantSize = getConstantIntValue (size)) {
198- sizeExpr = rewriter.getAffineConstantExpr (*constantSize);
199- } else {
200- sizeExpr = symbols[symbolIndex++];
201- offsetValues.push_back (
202- getValueOrCreateConstantIndexOp (rewriter, loc, size));
203- }
204-
205- productExpressions.push_back (strideExpr * sizeExpr);
206- }
207-
208- AffineMap maxMap = AffineMap::get (
209- /* dimCount=*/ 0 , /* symbolCount=*/ symbolIndex, productExpressions,
210- rewriter.getContext ());
211- Value totalSize =
212- rewriter.create <affine::AffineMaxOp>(loc, maxMap, offsetValues);
213-
214172 // delta = bufferSize - linearizedOffset
215173 Value vectorSizeOffset =
216174 rewriter.create <arith::ConstantIndexOp>(loc, vectorSize);
217175 Value linearIndex =
218176 getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices);
177+ Value totalSize = getValueOrCreateConstantIndexOp (
178+ rewriter, loc, linearizedInfo.linearizedSize );
219179 Value delta = rewriter.create <arith::SubIOp>(loc, totalSize, linearIndex);
220180
221181 // 1) check if delta < vectorSize
0 commit comments