Skip to content

Commit 9e7aedf

Browse files
committed
fix according to comments
1 parent d6437e9 commit 9e7aedf

File tree

1 file changed

+32
-41
lines changed

1 file changed

+32
-41
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 32 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -194,13 +194,14 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
194194
Value dest, OpFoldResult destOffsetVar,
195195
int64_t length) {
196196
assert(length > 0 && "length must be greater than 0");
197+
Value destOffsetVal =
198+
getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
197199
for (int i = 0; i < length; ++i) {
198-
Value insertLoc =
199-
i == 0
200-
? destOffsetVar.dyn_cast<Value>()
201-
: rewriter.create<arith::AddIOp>(
202-
loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
203-
rewriter.create<arith::ConstantIndexOp>(loc, i));
200+
auto insertLoc = i == 0
201+
? destOffsetVal
202+
: rewriter.create<arith::AddIOp>(
203+
loc, rewriter.getIndexType(), destOffsetVal,
204+
rewriter.create<arith::ConstantIndexOp>(loc, i));
204205
auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
205206
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
206207
}
@@ -465,18 +466,16 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
465466
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
466467
numElements, oldElementType, newElementType);
467468

468-
if (foldedIntraVectorOffset) {
469-
if (isUnalignedEmulation) {
470-
result =
471-
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
472-
*foldedIntraVectorOffset, origElements);
473-
}
474-
} else {
469+
if (!foldedIntraVectorOffset) {
475470
auto resultVector = rewriter.create<arith::ConstantOp>(
476471
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
477472
result = dynamicallyExtractSubVector(
478473
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
479474
linearizedInfo.intraDataOffset, origElements);
475+
} else if (isUnalignedEmulation) {
476+
result =
477+
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
478+
*foldedIntraVectorOffset, origElements);
480479
}
481480
rewriter.replaceOp(op, result);
482481
return success();
@@ -571,7 +570,7 @@ struct ConvertVectorMaskedLoad final
571570
? getConstantIntValue(linearizedInfo.intraDataOffset)
572571
: 0;
573572

574-
auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
573+
int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
575574
FailureOr<Operation *> newMask = getCompressedMaskOp(
576575
rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
577576
if (failed(newMask))
@@ -586,15 +585,13 @@ struct ConvertVectorMaskedLoad final
586585

587586
auto emptyVector = rewriter.create<arith::ConstantOp>(
588587
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
589-
if (foldedIntraVectorOffset) {
590-
if (isUnalignedEmulation) {
591-
passthru = staticallyInsertSubvector(
592-
rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset);
593-
}
594-
} else {
588+
if (!foldedIntraVectorOffset) {
595589
passthru = dynamicallyInsertSubVector(
596590
rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
597591
emptyVector, linearizedInfo.intraDataOffset, origElements);
592+
} else if (isUnalignedEmulation) {
593+
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
594+
*foldedIntraVectorOffset);
598595
}
599596
auto newPassThru =
600597
rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
@@ -616,29 +613,25 @@ struct ConvertVectorMaskedLoad final
616613
// TODO: try to fold if op's mask is constant
617614
auto emptyMask = rewriter.create<arith::ConstantOp>(
618615
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
619-
if (foldedIntraVectorOffset) {
620-
if (isUnalignedEmulation) {
621-
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
622-
*foldedIntraVectorOffset);
623-
}
624-
} else {
616+
if (!foldedIntraVectorOffset) {
625617
mask = dynamicallyInsertSubVector(
626618
rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
627619
linearizedInfo.intraDataOffset, origElements);
620+
} else if (isUnalignedEmulation) {
621+
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
622+
*foldedIntraVectorOffset);
628623
}
629624

630625
Value result =
631626
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
632-
if (foldedIntraVectorOffset) {
633-
if (isUnalignedEmulation) {
634-
result =
635-
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
636-
*foldedIntraVectorOffset, origElements);
637-
}
638-
} else {
627+
if (!foldedIntraVectorOffset) {
639628
result = dynamicallyExtractSubVector(
640629
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
641630
op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
631+
} else if (isUnalignedEmulation) {
632+
result =
633+
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
634+
*foldedIntraVectorOffset, origElements);
642635
}
643636
rewriter.replaceOp(op, result);
644637

@@ -696,7 +689,7 @@ struct ConvertVectorTransferRead final
696689
? getConstantIntValue(linearizedInfo.intraDataOffset)
697690
: 0;
698691

699-
auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
692+
int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
700693
auto numElements =
701694
llvm::divideCeil(maxIntraDataOffset + origElements, scale);
702695

@@ -709,18 +702,16 @@ struct ConvertVectorTransferRead final
709702
loc, VectorType::get(numElements * scale, oldElementType), newRead);
710703

711704
Value result = bitCast->getResult(0);
712-
if (foldedIntraVectorOffset) {
713-
if (isUnalignedEmulation) {
714-
result =
715-
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
716-
*foldedIntraVectorOffset, origElements);
717-
}
718-
} else {
705+
if (!foldedIntraVectorOffset) {
719706
auto zeros = rewriter.create<arith::ConstantOp>(
720707
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
721708
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
722709
linearizedInfo.intraDataOffset,
723710
origElements);
711+
} else if (isUnalignedEmulation) {
712+
result =
713+
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
714+
*foldedIntraVectorOffset, origElements);
724715
}
725716
rewriter.replaceOp(op, result);
726717

0 commit comments

Comments
 (0)