@@ -729,8 +729,8 @@ static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
-  if (resultBitwidth % 8 != 0)
+  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
+  if (bitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
@@ -768,6 +768,10 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
       (dstElemBitwidth % srcElemBitwidth) != 0)
     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
 
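+  // The i4 elements must pack pairwise into whole bytes; e.g., a trailing
+  // dim of vector<3xi4> has no byte-aligned i8 counterpart.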
+  if ((srcType.getShape().back() % 2) != 0)
+    return rewriter.notifyMatchFailure(
+        op, "Not an even number of i4 elements in trailing dim");
+
   return success();
 }
 
@@ -876,6 +880,58 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
+/// that take advantage of high-level information to avoid leaving LLVM to
+/// scramble with peephole optimizations.
+static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
+                                Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(8) &&
+         "Expected i8 type");
+
+  // 1. De-interleave low and high i8 elements.
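+  //    E.g., [a, b, c, d] de-interleaves into [a, c] (future low nibbles)
+  //    and [b, d] (future high nibbles).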
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> deinterleaveLowMaskValues;
+  SmallVector<int64_t> deinterleaveHighMaskValues;
+  assert((vecDimSize % 2) == 0 && "Odd number of i4 elements");
+  deinterleaveLowMaskValues.reserve(vecDimSize / 2);
+  deinterleaveHighMaskValues.reserve(vecDimSize / 2);
+  for (int i = 0, end = vecDimSize; i < end; i += 2) {
+    deinterleaveLowMaskValues.push_back(i);
+    deinterleaveHighMaskValues.push_back(i + 1);
+  }
+
+  auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, srcValue, srcValue,
+      rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
+  auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, srcValue, srcValue,
+      rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
+
+  // 2. Zero out the upper side of each low i8 element.
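+  //    E.g., 0xAB & 0x0F == 0x0B, keeping only the i4 payload of each byte.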
+  constexpr int8_t i8LowBitMask = 0x0F;
+  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
+  Value zeroOutLow =
+      rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
+
+  // 3. Move high i4 values to the upper side of the byte.
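+  //    E.g., 0x0B << 4 == 0xB0; the shift also discards any bits above the
+  //    i4 payload.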
+  constexpr int8_t bitsToShift = 4;
+  VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
+  Value shlHigh =
+      rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+
+  // 4. Merge high and low i4 values.
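+  //    E.g., 0x0A | 0xB0 == 0xBA, packing two i4 values into one byte.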
+  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
+
+  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
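+  //    E.g., vector<4xi8> -> vector<8xi4>, where each byte yields two
+  //    consecutive i4 elements.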
+  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
+  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -1019,7 +1075,7 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                 PatternRewriter &rewriter) const override {
-    // Set up the BitCastRewriter and verify the preconditions.
+    // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
     auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
     auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
@@ -1043,6 +1099,65 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///   arith.trunci %in : vector<8xi32> to vector<8xi4>
+/// is rewritten as:
+///
+///   %cst = arith.constant dense<15> : vector<4xi8>
+///   %cst_0 = arith.constant dense<4> : vector<4xi8>
+///   %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
+///   %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+///   %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+///   %3 = arith.andi %1, %cst : vector<4xi8>
+///   %4 = arith.shli %2, %cst_0 : vector<4xi8>
+///   %5 = arith.ori %3, %4 : vector<4xi8>
+///   %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
+///
+struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
+  using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions.
+    Value srcValue = truncOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
+    if (!srcVecType || !dstVecType)
+      return failure();
+
+    // Only single-dim vectors are supported until we have
+    // `vector.deinterleave`.
+    if (srcVecType.getRank() != 1)
+      return failure();
+
+    if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
+      return failure();
+
+    // Check general alignment preconditions. We invert the src/dst type order
+    // to reuse the existing precondition logic.
+    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
+                                             truncOp)))
+      return failure();
+
+    // Create a new iX -> i8 truncation op.
+    Location loc = truncOp.getLoc();
+    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+    Value i8TruncVal =
+        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
+
+    // Rewrite the i8 -> i4 truncation part.
+    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
+
+    // Finalize the rewrite.
+    rewriter.replaceOp(truncOp, subByteTrunc);
+    return success();
+  }
+};
+
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
 /// For example:
@@ -1115,8 +1230,9 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
   patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
-               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
-      patterns.getContext(), benefit.getBenefit() + 1);
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
+                                              benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(