@@ -52,7 +52,9 @@ using namespace mlir;
5252// /
5353// / %mask = [1, 1, 0, 0, 0, 0]
5454// /
55- // / will first be padded with number of `intraDataOffset` zeros:
55+ // / will first be padded in the front with number of `intraDataOffset` zeros,
56+ // / and pad zeros in the back to make the number of elements a multiple of
57+ // / `scale` (just to make it easier to compute). The new mask will be:
5658// / %mask = [0, 1, 1, 0, 0, 0, 0, 0]
5759// /
5860// / then it will return the following new compressed mask:
@@ -62,7 +64,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
6264 Location loc, Value mask,
6365 int origElements, int scale,
6466 int intraDataOffset = 0 ) {
65- auto numElements = (intraDataOffset + origElements + scale - 1 ) / scale;
67+ assert (intraDataOffset < scale && " intraDataOffset must be less than scale" );
68+ auto numElements = llvm::divideCeil (intraDataOffset + origElements, scale);
6669
6770 Operation *maskOp = mask.getDefiningOp ();
6871 SmallVector<vector::ExtractOp, 2 > extractOps;
@@ -194,6 +197,26 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
194197 return dest;
195198}
196199
200+ // / Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
201+ static Value dynamicallyInsertSubVector (RewriterBase &rewriter, Location loc,
202+ TypedValue<VectorType> source,
203+ Value dest, OpFoldResult destOffsetVar,
204+ size_t length) {
205+ assert (length > 0 && " length must be greater than 0" );
206+ Value destOffsetVal =
207+ getValueOrCreateConstantIndexOp (rewriter, loc, destOffsetVar);
208+ for (size_t i = 0 ; i < length; ++i) {
209+ auto insertLoc = i == 0
210+ ? destOffsetVal
211+ : rewriter.create <arith::AddIOp>(
212+ loc, rewriter.getIndexType (), destOffsetVal,
213+ rewriter.create <arith::ConstantIndexOp>(loc, i));
214+ auto extractOp = rewriter.create <vector::ExtractOp>(loc, source, i);
215+ dest = rewriter.create <vector::InsertOp>(loc, extractOp, dest, insertLoc);
216+ }
217+ return dest;
218+ }
219+
197220// / Returns the op sequence for an emulated sub-byte data type vector load.
198221// / specifically, use `emulatedElemType` for loading a vector of `origElemType`.
199222// / The load location is given by `base` and `linearizedIndices`, and the
@@ -466,18 +489,16 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
466489 emulatedVectorLoad (rewriter, loc, adaptor.getBase (), linearizedIndices,
467490 numElements, oldElementType, newElementType);
468491
469- if (foldedIntraVectorOffset) {
470- if (isUnalignedEmulation) {
471- result =
472- staticallyExtractSubvector (rewriter, loc, op.getType (), result,
473- *foldedIntraVectorOffset, origElements);
474- }
475- } else {
492+ if (!foldedIntraVectorOffset) {
476493 auto resultVector = rewriter.create <arith::ConstantOp>(
477494 loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
478495 result = dynamicallyExtractSubVector (
479496 rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
480497 linearizedInfo.intraDataOffset , origElements);
498+ } else if (isUnalignedEmulation) {
499+ result =
500+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
501+ *foldedIntraVectorOffset, origElements);
481502 }
482503 rewriter.replaceOp (op, result);
483504 return success ();
@@ -572,27 +593,26 @@ struct ConvertVectorMaskedLoad final
572593 ? getConstantIntValue (linearizedInfo.intraDataOffset )
573594 : 0 ;
574595
575- if (!foldedIntraVectorOffset) {
576- // unimplemented case for dynamic intra vector offset
577- return failure ();
578- }
579-
580- FailureOr<Operation *> newMask =
581- getCompressedMaskOp (rewriter, loc, op.getMask (), origElements, scale,
582- *foldedIntraVectorOffset);
596+ int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
597+ FailureOr<Operation *> newMask = getCompressedMaskOp (
598+ rewriter, loc, op.getMask (), origElements, scale, maxIntraDataOffset);
583599 if (failed (newMask))
584600 return failure ();
585601
602+ Value passthru = op.getPassThru ();
603+
586604 auto numElements =
587- llvm::divideCeil (*foldedIntraVectorOffset + origElements, scale);
605+ llvm::divideCeil (maxIntraDataOffset + origElements, scale);
588606 auto loadType = VectorType::get (numElements, newElementType);
589607 auto newBitcastType = VectorType::get (numElements * scale, oldElementType);
590608
591- Value passthru = op.getPassThru ();
592- if (isUnalignedEmulation) {
593- // create an empty vector of the new type
594- auto emptyVector = rewriter.create <arith::ConstantOp>(
595- loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
609+ auto emptyVector = rewriter.create <arith::ConstantOp>(
610+ loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
611+ if (!foldedIntraVectorOffset) {
612+ passthru = dynamicallyInsertSubVector (
613+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
614+ emptyVector, linearizedInfo.intraDataOffset , origElements);
615+ } else if (isUnalignedEmulation) {
596616 passthru = staticallyInsertSubvector (rewriter, loc, passthru, emptyVector,
597617 *foldedIntraVectorOffset);
598618 }
@@ -611,20 +631,27 @@ struct ConvertVectorMaskedLoad final
611631 rewriter.create <vector::BitCastOp>(loc, newBitcastType, newLoad);
612632
613633 Value mask = op.getMask ();
614- if (isUnalignedEmulation) {
615- auto newSelectMaskType =
616- VectorType::get (numElements * scale, rewriter.getI1Type ());
617- // TODO: can fold if op's mask is constant
618- auto emptyVector = rewriter.create <arith::ConstantOp>(
619- loc, newSelectMaskType, rewriter.getZeroAttr (newSelectMaskType));
620- mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyVector,
634+ auto newSelectMaskType =
635+ VectorType::get (numElements * scale, rewriter.getI1Type ());
636+ // TODO: try to fold if op's mask is constant
637+ auto emptyMask = rewriter.create <arith::ConstantOp>(
638+ loc, newSelectMaskType, rewriter.getZeroAttr (newSelectMaskType));
639+ if (!foldedIntraVectorOffset) {
640+ mask = dynamicallyInsertSubVector (
641+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
642+ linearizedInfo.intraDataOffset , origElements);
643+ } else if (isUnalignedEmulation) {
644+ mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyMask,
621645 *foldedIntraVectorOffset);
622646 }
623647
624648 Value result =
625649 rewriter.create <arith::SelectOp>(loc, mask, bitCast, passthru);
626-
627- if (isUnalignedEmulation) {
650+ if (!foldedIntraVectorOffset) {
651+ result = dynamicallyExtractSubVector (
652+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
653+ op.getPassThru (), linearizedInfo.intraDataOffset , origElements);
654+ } else if (isUnalignedEmulation) {
628655 result =
629656 staticallyExtractSubvector (rewriter, loc, op.getType (), result,
630657 *foldedIntraVectorOffset, origElements);
@@ -685,10 +712,9 @@ struct ConvertVectorTransferRead final
685712 ? getConstantIntValue (linearizedInfo.intraDataOffset )
686713 : 0 ;
687714
688- auto maxIntraVectorOffset =
689- foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1 ;
715+ int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
690716 auto numElements =
691- llvm::divideCeil (maxIntraVectorOffset + origElements, scale);
717+ llvm::divideCeil (maxIntraDataOffset + origElements, scale);
692718
693719 auto newRead = rewriter.create <vector::TransferReadOp>(
694720 loc, VectorType::get (numElements, newElementType), adaptor.getSource (),
@@ -699,18 +725,16 @@ struct ConvertVectorTransferRead final
699725 loc, VectorType::get (numElements * scale, oldElementType), newRead);
700726
701727 Value result = bitCast->getResult (0 );
702- if (foldedIntraVectorOffset) {
703- if (isUnalignedEmulation) {
704- result =
705- staticallyExtractSubvector (rewriter, loc, op.getType (), result,
706- *foldedIntraVectorOffset, origElements);
707- }
708- } else {
728+ if (!foldedIntraVectorOffset) {
709729 auto zeros = rewriter.create <arith::ConstantOp>(
710730 loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
711731 result = dynamicallyExtractSubVector (rewriter, loc, bitCast, zeros,
712732 linearizedInfo.intraDataOffset ,
713733 origElements);
734+ } else if (isUnalignedEmulation) {
735+ result =
736+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
737+ *foldedIntraVectorOffset, origElements);
714738 }
715739 rewriter.replaceOp (op, result);
716740
0 commit comments