@@ -53,6 +53,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
5353 Location loc, Value mask,
5454 int origElements, int scale,
5555 int intraDataOffset = 0 ) {
56+ assert (intraDataOffset < scale && " intraDataOffset must be less than scale" );
5657 auto numElements = (intraDataOffset + origElements + scale - 1 ) / scale;
5758
5859 Operation *maskOp = mask.getDefiningOp ();
@@ -182,6 +183,27 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
182183 return dest;
183184}
184185
186+ // / Inserts a 1-D subvector into a 1-D `dest` vector at index `offset`.
187+ static Value dynamicallyInsertSubVector (RewriterBase &rewriter, Location loc,
188+ TypedValue<VectorType> source,
189+ Value dest, OpFoldResult destOffsetVar,
190+ int64_t length) {
191+ assert (length > 0 && " length must be greater than 0" );
192+ for (int i = 0 ; i < length; ++i) {
193+ Value insertLoc;
194+ if (i == 0 ) {
195+ insertLoc = destOffsetVar.dyn_cast <Value>();
196+ } else {
197+ insertLoc = rewriter.create <arith::AddIOp>(
198+ loc, rewriter.getIndexType (), destOffsetVar.dyn_cast <Value>(),
199+ rewriter.create <arith::ConstantIndexOp>(loc, i));
200+ }
201+ auto extractOp = rewriter.create <vector::ExtractOp>(loc, source, i);
202+ dest = rewriter.create <vector::InsertOp>(loc, extractOp, dest, insertLoc);
203+ }
204+ return dest;
205+ }
206+
185207// / Returns the op sequence for an emulated sub-byte data type vector load.
186208// / Specifically, use `emulatedElemType` for loading a vector of `origElemType`.
187209// / The load location is given by `base` and `linearizedIndices`, and the
@@ -199,7 +221,7 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
199221 return rewriter.create <vector::BitCastOp>(
200222 loc, VectorType::get (numEmultedElementsToLoad * scale, origElemType),
201223 newLoad);
202- };
224+ }
203225
204226namespace {
205227
@@ -546,29 +568,30 @@ struct ConvertVectorMaskedLoad final
546568 ? getConstantIntValue (linearizedInfo.intraDataOffset )
547569 : 0 ;
548570
549- if (!foldedIntraVectorOffset) {
550- // unimplemented case for dynamic intra vector offset
551- return failure ();
552- }
553-
554- FailureOr<Operation *> newMask =
555- getCompressedMaskOp (rewriter, loc, op.getMask (), origElements, scale,
556- *foldedIntraVectorOffset);
571+ auto maxIntraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
572+ FailureOr<Operation *> newMask = getCompressedMaskOp (
573+ rewriter, loc, op.getMask (), origElements, scale, maxIntraDataOffset);
557574 if (failed (newMask))
558575 return failure ();
559576
577+ Value passthru = op.getPassThru ();
578+
560579 auto numElements =
561- llvm::divideCeil (*foldedIntraVectorOffset + origElements, scale);
580+ llvm::divideCeil (maxIntraDataOffset + origElements, scale);
562581 auto loadType = VectorType::get (numElements, newElementType);
563582 auto newBitcastType = VectorType::get (numElements * scale, oldElementType);
564583
565- Value passthru = op.getPassThru ();
566- if (isUnalignedEmulation) {
567- // create an empty vector of the new type
568- auto emptyVector = rewriter.create <arith::ConstantOp>(
569- loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
570- passthru = staticallyInsertSubvector (rewriter, loc, passthru, emptyVector,
571- *foldedIntraVectorOffset);
584+ auto emptyVector = rewriter.create <arith::ConstantOp>(
585+ loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
586+ if (foldedIntraVectorOffset) {
587+ if (isUnalignedEmulation) {
588+ passthru = staticallyInsertSubvector (
589+ rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset);
590+ }
591+ } else {
592+ passthru = dynamicallyInsertSubVector (
593+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
594+ emptyVector, linearizedInfo.intraDataOffset , origElements);
572595 }
573596 auto newPassThru =
574597 rewriter.create <vector::BitCastOp>(loc, loadType, passthru);
@@ -585,23 +608,36 @@ struct ConvertVectorMaskedLoad final
585608 rewriter.create <vector::BitCastOp>(loc, newBitcastType, newLoad);
586609
587610 Value mask = op.getMask ();
588- if (isUnalignedEmulation) {
589- auto newSelectMaskType =
590- VectorType::get (numElements * scale, rewriter.getI1Type ());
591- // TODO: can fold if op's mask is constant
592- auto emptyVector = rewriter.create <arith::ConstantOp>(
593- loc, newSelectMaskType, rewriter.getZeroAttr (newSelectMaskType));
594- mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyVector,
595- *foldedIntraVectorOffset);
611+ auto newSelectMaskType =
612+ VectorType::get (numElements * scale, rewriter.getI1Type ());
613+ // TODO: try to fold if op's mask is constant
614+ auto emptyMask = rewriter.create <arith::ConstantOp>(
615+ loc, newSelectMaskType, rewriter.getZeroAttr (newSelectMaskType));
616+ if (foldedIntraVectorOffset) {
617+ if (isUnalignedEmulation) {
618+ mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyMask,
619+ *foldedIntraVectorOffset);
620+ }
621+ } else {
622+ mask = dynamicallyInsertSubVector (
623+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
624+ linearizedInfo.intraDataOffset , origElements);
596625 }
597626
598627 Value result =
599628 rewriter.create <arith::SelectOp>(loc, mask, bitCast, passthru);
600-
601- if (isUnalignedEmulation) {
602- result =
603- staticallyExtractSubvector (rewriter, loc, op.getType (), result,
604- *foldedIntraVectorOffset, origElements);
629+ if (foldedIntraVectorOffset) {
630+ if (isUnalignedEmulation) {
631+ result =
632+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
633+ *foldedIntraVectorOffset, origElements);
634+ }
635+ } else {
636+ auto resultVector = rewriter.create <arith::ConstantOp>(
637+ loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
638+ result = dynamicallyExtractSubVector (
639+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
640+ linearizedInfo.intraDataOffset , origElements);
605641 }
606642 rewriter.replaceOp (op, result);
607643
@@ -659,10 +695,9 @@ struct ConvertVectorTransferRead final
659695 ? getConstantIntValue (linearizedInfo.intraDataOffset )
660696 : 0 ;
661697
662- auto maxIntraVectorOffset =
663- foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1 ;
698+ auto maxIntraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
664699 auto numElements =
665- llvm::divideCeil (maxIntraVectorOffset + origElements, scale);
700+ llvm::divideCeil (maxIntraDataOffset + origElements, scale);
666701
667702 auto newRead = rewriter.create <vector::TransferReadOp>(
668703 loc, VectorType::get (numElements, newElementType), adaptor.getSource (),
0 commit comments