@@ -194,13 +194,14 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
194194 Value dest, OpFoldResult destOffsetVar,
195195 int64_t length) {
196196 assert (length > 0 && " length must be greater than 0" );
197+ Value destOffsetVal =
198+ getValueOrCreateConstantIndexOp (rewriter, loc, destOffsetVar);
197199 for (int i = 0 ; i < length; ++i) {
198- Value insertLoc =
199- i == 0
200- ? destOffsetVar.dyn_cast <Value>()
201- : rewriter.create <arith::AddIOp>(
202- loc, rewriter.getIndexType (), destOffsetVar.dyn_cast <Value>(),
203- rewriter.create <arith::ConstantIndexOp>(loc, i));
200+ auto insertLoc = i == 0
201+ ? destOffsetVal
202+ : rewriter.create <arith::AddIOp>(
203+ loc, rewriter.getIndexType (), destOffsetVal,
204+ rewriter.create <arith::ConstantIndexOp>(loc, i));
204205 auto extractOp = rewriter.create <vector::ExtractOp>(loc, source, i);
205206 dest = rewriter.create <vector::InsertOp>(loc, extractOp, dest, insertLoc);
206207 }
@@ -465,18 +466,16 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
465466 emulatedVectorLoad (rewriter, loc, adaptor.getBase (), linearizedIndices,
466467 numElements, oldElementType, newElementType);
467468
468- if (foldedIntraVectorOffset) {
469- if (isUnalignedEmulation) {
470- result =
471- staticallyExtractSubvector (rewriter, loc, op.getType (), result,
472- *foldedIntraVectorOffset, origElements);
473- }
474- } else {
469+ if (!foldedIntraVectorOffset) {
475470 auto resultVector = rewriter.create <arith::ConstantOp>(
476471 loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
477472 result = dynamicallyExtractSubVector (
478473 rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
479474 linearizedInfo.intraDataOffset , origElements);
475+ } else if (isUnalignedEmulation) {
476+ result =
477+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
478+ *foldedIntraVectorOffset, origElements);
480479 }
481480 rewriter.replaceOp (op, result);
482481 return success ();
@@ -571,7 +570,7 @@ struct ConvertVectorMaskedLoad final
571570 ? getConstantIntValue (linearizedInfo.intraDataOffset )
572571 : 0 ;
573572
574- auto maxIntraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
573+ int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
575574 FailureOr<Operation *> newMask = getCompressedMaskOp (
576575 rewriter, loc, op.getMask (), origElements, scale, maxIntraDataOffset);
577576 if (failed (newMask))
@@ -586,15 +585,13 @@ struct ConvertVectorMaskedLoad final
586585
587586 auto emptyVector = rewriter.create <arith::ConstantOp>(
588587 loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
589- if (foldedIntraVectorOffset) {
590- if (isUnalignedEmulation) {
591- passthru = staticallyInsertSubvector (
592- rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset);
593- }
594- } else {
588+ if (!foldedIntraVectorOffset) {
595589 passthru = dynamicallyInsertSubVector (
596590 rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
597591 emptyVector, linearizedInfo.intraDataOffset , origElements);
592+ } else if (isUnalignedEmulation) {
593+ passthru = staticallyInsertSubvector (rewriter, loc, passthru, emptyVector,
594+ *foldedIntraVectorOffset);
598595 }
599596 auto newPassThru =
600597 rewriter.create <vector::BitCastOp>(loc, loadType, passthru);
@@ -616,29 +613,25 @@ struct ConvertVectorMaskedLoad final
616613 // TODO: try to fold if op's mask is constant
617614 auto emptyMask = rewriter.create <arith::ConstantOp>(
618615 loc, newSelectMaskType, rewriter.getZeroAttr (newSelectMaskType));
619- if (foldedIntraVectorOffset) {
620- if (isUnalignedEmulation) {
621- mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyMask,
622- *foldedIntraVectorOffset);
623- }
624- } else {
616+ if (!foldedIntraVectorOffset) {
625617 mask = dynamicallyInsertSubVector (
626618 rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
627619 linearizedInfo.intraDataOffset , origElements);
620+ } else if (isUnalignedEmulation) {
621+ mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyMask,
622+ *foldedIntraVectorOffset);
628623 }
629624
630625 Value result =
631626 rewriter.create <arith::SelectOp>(loc, mask, bitCast, passthru);
632- if (foldedIntraVectorOffset) {
633- if (isUnalignedEmulation) {
634- result =
635- staticallyExtractSubvector (rewriter, loc, op.getType (), result,
636- *foldedIntraVectorOffset, origElements);
637- }
638- } else {
627+ if (!foldedIntraVectorOffset) {
639628 result = dynamicallyExtractSubVector (
640629 rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
641630 op.getPassThru (), linearizedInfo.intraDataOffset , origElements);
631+ } else if (isUnalignedEmulation) {
632+ result =
633+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
634+ *foldedIntraVectorOffset, origElements);
642635 }
643636 rewriter.replaceOp (op, result);
644637
@@ -696,7 +689,7 @@ struct ConvertVectorTransferRead final
696689 ? getConstantIntValue (linearizedInfo.intraDataOffset )
697690 : 0 ;
698691
699- auto maxIntraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
692+ int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
700693 auto numElements =
701694 llvm::divideCeil (maxIntraDataOffset + origElements, scale);
702695
@@ -709,18 +702,16 @@ struct ConvertVectorTransferRead final
709702 loc, VectorType::get (numElements * scale, oldElementType), newRead);
710703
711704 Value result = bitCast->getResult (0 );
712- if (foldedIntraVectorOffset) {
713- if (isUnalignedEmulation) {
714- result =
715- staticallyExtractSubvector (rewriter, loc, op.getType (), result,
716- *foldedIntraVectorOffset, origElements);
717- }
718- } else {
705+ if (!foldedIntraVectorOffset) {
719706 auto zeros = rewriter.create <arith::ConstantOp>(
720707 loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
721708 result = dynamicallyExtractSubVector (rewriter, loc, bitCast, zeros,
722709 linearizedInfo.intraDataOffset ,
723710 origElements);
711+ } else if (isUnalignedEmulation) {
712+ result =
713+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
714+ *foldedIntraVectorOffset, origElements);
724715 }
725716 rewriter.replaceOp (op, result);
726717
0 commit comments