Skip to content

Commit 0b9bdce

Browse files
committed
fix bugs
1 parent 2580b46 commit 0b9bdce

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,11 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
156156
/// the offset specified by `srcOffsetVar`. Use this function when
157157
/// `srcOffsetVar` is not a constant, making it impossible to use
158158
/// vector.extract_strided_slice, as it requires constant offsets.
159-
static void dynamicallyExtractElementsToVector(
160-
RewriterBase &rewriter, Location loc, TypedValue<VectorType> srcVec,
161-
Value destVec, OpFoldResult srcOffsetVar, int64_t lengthSubvec) {
159+
static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
160+
TypedValue<VectorType> srcVec,
161+
Value destVec,
162+
OpFoldResult srcOffsetVar,
163+
int64_t lengthSubvec) {
162164
for (int i = 0; i < lengthSubvec; ++i) {
163165
Value extractLoc;
164166
if (i == 0) {
@@ -170,8 +172,9 @@ static void dynamicallyExtractElementsToVector(
170172
}
171173
auto extractOp =
172174
rewriter.create<vector::ExtractOp>(loc, srcVec, extractLoc);
173-
rewriter.create<vector::InsertOp>(loc, extractOp, destVec, i);
175+
destVec = rewriter.create<vector::InsertOp>(loc, extractOp, destVec, i);
174176
}
177+
return destVec;
175178
}
176179

177180
/// Load `numLoadedElements` of `newElementType` from `base` at
@@ -436,15 +439,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
436439
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
437440
*foldedIntraVectorOffset, origElements);
438441
}
439-
rewriter.replaceOp(op, result);
440442
} else {
441443
auto resultVector = rewriter.create<arith::ConstantOp>(
442444
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
443-
dynamicallyExtractElementsToVector(
445+
result = dynamicallyExtractSubVector(
444446
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
445447
linearizedInfo.intraDataOffset, origElements);
446-
rewriter.replaceOp(op, resultVector);
447448
}
449+
rewriter.replaceOp(op, result);
448450
return success();
449451
}
450452
};
@@ -669,11 +671,11 @@ struct ConvertVectorTransferRead final
669671
*foldedIntraVectorOffset, origElements);
670672
}
671673
} else {
672-
result = rewriter.create<arith::ConstantOp>(
674+
auto zeros = rewriter.create<arith::ConstantOp>(
673675
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
674-
dynamicallyExtractElementsToVector(rewriter, loc, bitCast, result,
675-
linearizedInfo.intraDataOffset,
676-
origElements);
676+
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
677+
linearizedInfo.intraDataOffset,
678+
origElements);
677679
}
678680
rewriter.replaceOp(op, result);
679681

0 commit comments

Comments
 (0)