@@ -21,7 +21,9 @@ SPDX-License-Identifier: MIT
2121#include < llvm/Analysis/ScalarEvolution.h>
2222#include < llvm/Analysis/ScalarEvolutionExpressions.h>
2323#include < llvm/Analysis/TargetFolder.h>
24+ #include < llvm/Analysis/ValueTracking.h>
2425#include < llvm/IR/GetElementPtrTypeIterator.h>
26+ #include < llvm/Support/KnownBits.h>
2527#include < llvm/Transforms/Utils/ScalarEvolutionExpander.h>
2628#include < llvm/Transforms/Utils/Local.h>
2729#include " llvmWrapper/IR/Intrinsics.h"
@@ -63,6 +65,7 @@ class GenIRLowering : public FunctionPass {
6365
6466 bool combineFMaxFMin (CallInst *GII, BasicBlock::iterator &BBI) const ;
6567 bool combineSelectInst (SelectInst *Sel, BasicBlock::iterator &BBI) const ;
68+ bool combinePack4i8Or2i16 (Instruction *inst, uint64_t numBits) const ;
6669
6770 bool constantFoldFMaxFMin (CallInst *GII, BasicBlock::iterator &BBI) const ;
6871};
@@ -362,6 +365,15 @@ bool GenIRLowering::runOnFunction(Function &F) {
362365 Changed |= combineSelectInst (cast<SelectInst>(Inst), BI);
363366 }
364367 break ;
368+ case Instruction::Or:
369+ if (Inst->getType ()->isIntegerTy (32 )) {
370+ // Detect packing of 4 i8 values and convert to a pattern that is
371+ // matched CodeGenPatternMatch::MatchPack4i8
372+ Changed |= combinePack4i8Or2i16 (Inst, 8 /* numBits*/ );
373+ // TODO: also detect <2 x i16> packing once PatternMatch is updated
374+ // to packing of 16-bit values.
375+ }
376+ break ;
365377 }
366378 }
367379 }
@@ -1000,6 +1012,173 @@ bool GenIRLowering::combineSelectInst(SelectInst *Sel, BasicBlock::iterator &BBI
10001012 return false ;
10011013}
10021014
1015+ // //////////////////////////////////////////////////////////////////////////////
1016+ // Detect complex patterns that pack 2 16-bit or 4 8-bit integers into a 32-bit
1017+ // value. Generate equivalent sequence of instructions that is later matched in
1018+ // the CodeGenPatternMatch::MatchPack4i8().
1019+ // Pattern example for <4 x i8> packing:
1020+ // %x1 = and i32 %x, 127
1021+ // %x2 = lshr i32 %x, 24
1022+ // %x3 = and i32 %x2, 128
1023+ // %x4 = or i32 %x3, %x1
1024+ // %y1 = and i32 %y, 127
1025+ // %y2 = lshr i32 %y, 24
1026+ // %y3 = and i32 %y2, 128
1027+ // %y4 = or i32 %y3, %y1
1028+ // %y5 = shl nuw nsw i32 %y4, 8
1029+ // %xy = or i32 %x4, %y5
1030+ // %z1 = and i32 %z, 127
1031+ // %z2 = lshr i32 %z, 24
1032+ // %z3 = and i32 %z2, 128
1033+ // %z4 = or i32 %z3, %z1
1034+ // %z5 = shl nuw nsw i32 %z4, 16
1035+ // %xyz = or i32 %xy, %z5
1036+ // %w1 = shl nsw i32 %w, 24
1037+ // %w2 = and i32 %w1, 2130706432
1038+ // %w3 = and i32 %w, -2147483648
1039+ // %w4 = or i32 %w2, %w3
1040+ // %xyzw = or i32 %xyz, %w4
1041+ // and generate:
1042+ // %0 = trunc i32 %x to i8
1043+ // %1 = insertelement <4 x i8> poison, i8 %0, i32 0
1044+ // %2 = trunc i32 %y to i8
1045+ // %3 = insertelement <4 x i8> %1, i8 %2, i32 1
1046+ // %4 = trunc i32 %z to i8
1047+ // %5 = insertelement <4 x i8> %3, i8 %4, i32 2
1048+ // %6 = trunc i32 %w to i8
1049+ // %7 = insertelement <4 x i8> %5, i8 %6, i32 3
1050+ // %8 = bitcast <4 x i8> %7 to i32
1051+ bool GenIRLowering::combinePack4i8Or2i16 (Instruction *inst, uint64_t numBits) const {
1052+ using namespace llvm ::PatternMatch;
1053+
1054+ const DataLayout &DL = inst->getModule ()->getDataLayout ();
1055+ // Vector of 4 or 2 values that will be packed into a single 32-bit value.
1056+ // The std::pair contains the 32-bit value that contains the element
1057+ // to pack and the LSB where the packed value starts in the 32-bit value.
1058+ SmallVector<std::pair<Value *, uint64_t >, 4 > toPack;
1059+ IGC_ASSERT (numBits == 8 || numBits == 16 );
1060+ uint64_t packedVecSize = 32 / numBits;
1061+ toPack.resize (packedVecSize);
1062+ uint64_t cSignMask = QWBIT (numBits - 1 );
1063+ uint64_t cMagnMask = BITMASK (numBits - 1 );
1064+ // The std::pair contains the 32-bit value that contains the element
1065+ // to pack and the left shift bits that indicate the element position
1066+ // in the packed vector.
1067+ SmallVector<std::pair<Value *, uint64_t >, 4 > args;
1068+ args.push_back ({isa<BitCastInst>(inst) ? inst->getOperand (0 ) : inst, 0 });
1069+ // In the first step traverse the chain of `or` and `shl` instructions
1070+ // and find all elements of the packed vector.
1071+ while (!args.empty ()) {
1072+ auto [v, prevShlBits] = args.pop_back_val ();
1073+ Value *lOp = nullptr ;
1074+ Value *rOp = nullptr ;
1075+
1076+ // Detect left shift by multiple of `numBits`. The `shl` operation sets the
1077+ // `index` argument in the corresponding InsertElement instruction in the
1078+ // final packing sequence. This operation can also be viewed as repacking
1079+ // of already packed vector into another packed vector.
1080+ uint64_t shlBits = 0 ;
1081+ if (match (v, m_Shl (m_Value (lOp), m_ConstantInt (shlBits))) && (shlBits % numBits) == 0 ) {
1082+ args.push_back ({lOp, shlBits + prevShlBits});
1083+ continue ;
1084+ }
1085+ // Detect values that fit into `numBits` bits - a single element of
1086+ // the packed vector.
1087+ KnownBits kb = computeKnownBits (v, DL);
1088+ uint32_t nonZeroBits = ~(static_cast <uint32_t >(kb.Zero .getZExtValue ()));
1089+ uint32_t lsb = findFirstSet (nonZeroBits);
1090+ uint32_t msb = findLastSet (nonZeroBits);
1091+ if (msb != lsb && (msb / numBits) == (lsb / numBits)) {
1092+ uint32_t idx = (prevShlBits / numBits) + (lsb / numBits);
1093+ if (idx < packedVecSize && toPack[idx].first == nullptr ) {
1094+ toPack[idx] = std::make_pair (v, alignDown (lsb, numBits));
1095+ continue ;
1096+ }
1097+ }
1098+ // Detect packing of two disjoint values. This `or` operation corresponds
1099+ // to an InsertElement instruction in the final packing sequence.
1100+ if (match (v, m_Or (m_Value (lOp), m_Value (rOp)))) {
1101+ KnownBits kbL = computeKnownBits (lOp, DL);
1102+ KnownBits kbR = computeKnownBits (rOp, DL);
1103+ uint32_t nonZeroBitsL = ~(static_cast <uint32_t >(kbL.Zero .getZExtValue ()));
1104+ uint32_t nonZeroBitsR = ~(static_cast <uint32_t >(kbR.Zero .getZExtValue ()));
1105+ if ((nonZeroBitsL & nonZeroBitsR) == 0 ) {
1106+ args.push_back ({lOp, prevShlBits});
1107+ args.push_back ({rOp, prevShlBits});
1108+ }
1109+ continue ;
1110+ }
1111+ if (std::all_of (toPack.begin (), toPack.end (), [](const auto &c) { return c.first != nullptr ; })) {
1112+ break ;
1113+ }
1114+ // Unsupported pattern.
1115+ return false ;
1116+ }
1117+ if (std::any_of (toPack.begin (), toPack.end (), [](const auto &c) { return c.first == nullptr ; })) {
1118+ return false ;
1119+ }
1120+ // In the second step match the pattern that packs sign and magnitude parts
1121+ // and simple masking with `and` instruction.
1122+ for (uint32_t i = 0 ; i < packedVecSize; ++i) {
1123+ auto [v, lsb] = toPack[i];
1124+ Value *lOp = nullptr ;
1125+ Value *rOp = nullptr ;
1126+ uint64_t lMask = 0 ;
1127+ uint64_t rMask = 0 ;
1128+ // Match patterns that pack the sign and magnitude parts.
1129+ if (match (v, m_Or (m_And (m_Value (lOp), m_ConstantInt (lMask)), m_And (m_Value (rOp), m_ConstantInt (rMask)))) &&
1130+ (countPopulation (rMask) == 1 || countPopulation (lMask) == 1 )) {
1131+ Value *signOp = countPopulation (rMask) == 1 ? rOp : lOp;
1132+ Value *magnOp = countPopulation (rMask) == 1 ? lOp : rOp;
1133+ uint64_t signMask = countPopulation (rMask) == 1 ? rMask : lMask;
1134+ uint64_t magnMask = countPopulation (rMask) == 1 ? lMask : rMask;
1135+ uint64_t shlBits = 0 ;
1136+ uint64_t shrBits = 0 ;
1137+ // %b = shl nsw i32 %a, 24
1138+ // %c = and i32 %b, 2130706432
1139+ // %sign = and i32 %a, -2147483648
1140+ // %e = or i32 %sign, %c
1141+ if (match (magnOp, m_Shl (m_Value (v), m_ConstantInt (shlBits))) && v == signOp && (shlBits % numBits) == 0 &&
1142+ shlBits == (i * numBits) && (cSignMask << shlBits) == signMask && (cMagnMask << shlBits) == magnMask &&
1143+ lsb == shlBits) {
1144+ toPack[i] = std::make_pair (v, 0 );
1145+ continue ;
1146+ }
1147+ // %b = and i32 %a, 127
1148+ // %c = lshr i32 %a, 24
1149+ // %sign = and i32 %c, 128
1150+ // %e = or i32 %sign, %b
1151+ if (match (signOp, m_LShr (m_Value (v), m_ConstantInt (shrBits))) && v == magnOp && shrBits == (32 - numBits) &&
1152+ cSignMask == signMask && cMagnMask == magnMask && lsb == 0 ) {
1153+ toPack[i] = std::make_pair (v, 0 );
1154+ continue ;
1155+ }
1156+ }
1157+ uint64_t andMask = 0 ;
1158+ if (match (v, m_And (m_Value (lOp), m_ConstantInt (andMask))) && andMask == BITMASK (numBits) && lsb == 0 ) {
1159+ toPack[i] = std::make_pair (lOp, 0 );
1160+ continue ;
1161+ }
1162+ if (lsb > 0 ) {
1163+ return false ;
1164+ }
1165+ }
1166+
1167+ // Create the packing sequence that is matched in the PatternMatch later.
1168+ Type *elemTy = Builder->getIntNTy (numBits);
1169+ Value *packed = PoisonValue::get (IGCLLVM::FixedVectorType::get (elemTy, packedVecSize));
1170+ for (uint32_t i = 0 ; i < packedVecSize; ++i) {
1171+ auto [elem, lsb] = toPack[i];
1172+ IGC_ASSERT (lsb == 0 );
1173+ elem = Builder->CreateTrunc (elem, elemTy);
1174+ packed = Builder->CreateInsertElement (packed, elem, Builder->getInt32 (i));
1175+ }
1176+ packed = Builder->CreateBitCast (packed, inst->getType ());
1177+ inst->replaceAllUsesWith (packed);
1178+ inst->eraseFromParent ();
1179+ return true ;
1180+ }
1181+
10031182FunctionPass *IGC::createGenIRLowerPass () { return new GenIRLowering (); }
10041183
10051184// Register pass to igc-opt
0 commit comments