Skip to content

Commit 6caf55d

Browse files
mmereckiigcbot
authored and committed
Another pattern match for packing <4 x i8> values.
This PR adds detection for one more pattern that packs four 8-bit integer values into a single 32-bit value.
1 parent b37c5cd commit 6caf55d

File tree

3 files changed

+190
-9
lines changed

3 files changed

+190
-9
lines changed

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2426,6 +2426,7 @@ void EmitPass::EmitPack4i8(const std::array<EOPCODE, 4> &opcodes, const std::arr
24262426
CVariable *src0 = GetSrcVariable(sources0[i]);
24272427
switch (opcodes[i]) {
24282428
case llvm_bitcast:
2429+
case llvm_fptosi:
24292430
m_encoder->Cast(dst, src0);
24302431
break;
24312432
case llvm_min:

IGC/Compiler/CISACodeGen/GenIRLowering.cpp

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ SPDX-License-Identifier: MIT
2121
#include <llvm/Analysis/ScalarEvolution.h>
2222
#include <llvm/Analysis/ScalarEvolutionExpressions.h>
2323
#include <llvm/Analysis/TargetFolder.h>
24+
#include <llvm/Analysis/ValueTracking.h>
2425
#include <llvm/IR/GetElementPtrTypeIterator.h>
26+
#include <llvm/Support/KnownBits.h>
2527
#include <llvm/Transforms/Utils/ScalarEvolutionExpander.h>
2628
#include <llvm/Transforms/Utils/Local.h>
2729
#include "llvmWrapper/IR/Intrinsics.h"
@@ -63,6 +65,7 @@ class GenIRLowering : public FunctionPass {
6365

6466
bool combineFMaxFMin(CallInst *GII, BasicBlock::iterator &BBI) const;
6567
bool combineSelectInst(SelectInst *Sel, BasicBlock::iterator &BBI) const;
68+
bool combinePack4i8Or2i16(Instruction *inst, uint64_t numBits) const;
6669

6770
bool constantFoldFMaxFMin(CallInst *GII, BasicBlock::iterator &BBI) const;
6871
};
@@ -362,6 +365,15 @@ bool GenIRLowering::runOnFunction(Function &F) {
362365
Changed |= combineSelectInst(cast<SelectInst>(Inst), BI);
363366
}
364367
break;
368+
case Instruction::Or:
369+
if (Inst->getType()->isIntegerTy(32)) {
370+
// Detect packing of 4 i8 values and convert to a pattern that is
371+
// matched by CodeGenPatternMatch::MatchPack4i8
372+
Changed |= combinePack4i8Or2i16(Inst, 8 /*numBits*/);
373+
// TODO: also detect <2 x i16> packing once PatternMatch is updated
374+
// to support packing of 16-bit values.
375+
}
376+
break;
365377
}
366378
}
367379
}
@@ -1000,6 +1012,162 @@ bool GenIRLowering::combineSelectInst(SelectInst *Sel, BasicBlock::iterator &BBI
10001012
return false;
10011013
}
10021014

1015+
////////////////////////////////////////////////////////////////////////////////
1016+
// Detect complex patterns that pack 2 16-bit or 4 8-bit integers into a 32-bit
1017+
// value. Generate equivalent sequence of instructions that is later matched in
1018+
// the CodeGenPatternMatch::MatchPack4i8().
1019+
// Pattern example for <4 x i8> packing:
1020+
// %x1 = and i32 %x, 127
1021+
// %x2 = lshr i32 %x, 24
1022+
// %x3 = and i32 %x2, 128
1023+
// %x4 = or i32 %x3, %x1
1024+
// %y1 = and i32 %y, 127
1025+
// %y2 = lshr i32 %y, 24
1026+
// %y3 = and i32 %y2, 128
1027+
// %y4 = or i32 %y3, %y1
1028+
// %y5 = shl nuw nsw i32 %y4, 8
1029+
// %xy = or i32 %x4, %y5
1030+
// %z1 = and i32 %z, 127
1031+
// %z2 = lshr i32 %z, 24
1032+
// %z3 = and i32 %z2, 128
1033+
// %z4 = or i32 %z3, %z1
1034+
// %z5 = shl nuw nsw i32 %z4, 16
1035+
// %xyz = or i32 %xy, %z5
1036+
// %w1 = shl nsw i32 %w, 24
1037+
// %w2 = and i32 %w1, 2130706432
1038+
// %w3 = and i32 %w, -2147483648
1039+
// %w4 = or i32 %w2, %w3
1040+
// %xyzw = or i32 %xyz, %w4
1041+
// and generate:
1042+
// %0 = trunc i32 %x to i8
1043+
// %1 = insertelement <4 x i8> poison, i8 %0, i32 0
1044+
// %2 = trunc i32 %y to i8
1045+
// %3 = insertelement <4 x i8> %1, i8 %2, i32 1
1046+
// %4 = trunc i32 %z to i8
1047+
// %5 = insertelement <4 x i8> %3, i8 %4, i32 2
1048+
// %6 = trunc i32 %w to i8
1049+
// %7 = insertelement <4 x i8> %5, i8 %6, i32 3
1050+
// %8 = bitcast <4 x i8> %7 to i32
1051+
// Rewrites an `or`/`shl` tree that packs `32 / numBits` narrow integers into
// one 32-bit value as a trunc + insertelement + bitcast sequence.
//
// inst    - root of the packing expression (the 32-bit `or`, or a bitcast of
//           it, in which case the bitcast operand is analysed).
// numBits - width of a packed element; must be 8 or 16.
//
// Returns true when `inst` was replaced and erased, false when the pattern is
// not recognised (the IR is left untouched in that case).
//
// NOTE(review): new instructions are emitted through `Builder` without setting
// an insertion point here - presumably the caller positions it at `inst`;
// confirm against runOnFunction.
bool GenIRLowering::combinePack4i8Or2i16(Instruction *inst, uint64_t numBits) const {
  using namespace llvm::PatternMatch;

  IGC_ASSERT(numBits == 8 || numBits == 16);
  const DataLayout &DL = inst->getModule()->getDataLayout();
  const uint64_t packedVecSize = 32 / numBits;
  const uint64_t cSignMask = QWBIT(numBits - 1);
  const uint64_t cMagnMask = BITMASK(numBits - 1);

  // Values that will become the elements of the packed vector (null = lane
  // not found yet).
  SmallVector<Value *, 4> toPack;
  toPack.resize(packedVecSize);
  // For every lane: bit position of the element inside toPack[lane] itself.
  // Zero means the element already sits in the low `numBits` bits, so a plain
  // trunc extracts it.
  SmallVector<uint64_t, 4> bitOffset;
  bitOffset.resize(packedVecSize);

  // Work list of (value, accumulated left shift applied on the way down).
  SmallVector<std::pair<Value *, uint64_t>, 4> args;
  args.push_back({isa<BitCastInst>(inst) ? inst->getOperand(0) : inst, 0});

  // Step 1: walk the chain of `or` and `shl` instructions and collect all
  // elements of the packed vector.
  while (!args.empty()) {
    auto [v, prevShlBits] = args.pop_back_val();
    Value *lOp = nullptr;
    Value *rOp = nullptr;

    // A left shift by a multiple of `numBits` only moves the operand to a
    // higher lane; remember the shift and keep descending.
    uint64_t shlBits = 0;
    if (match(v, m_Shl(m_Value(lOp), m_ConstantInt(shlBits))) && (shlBits % numBits) == 0) {
      args.push_back({lOp, shlBits + prevShlBits});
      continue;
    }

    // A value whose possibly-set bits (at least two of them) all fall into a
    // single `numBits`-wide lane is one element of the packed vector.
    const KnownBits kb = computeKnownBits(v, DL);
    const uint32_t nonZeroBits = ~(static_cast<uint32_t>(kb.Zero.getZExtValue()));
    const uint32_t lsb = findFirstSet(nonZeroBits);
    const uint32_t msb = findLastSet(nonZeroBits);
    if (msb != lsb && (msb / numBits) == (lsb / numBits)) {
      const uint64_t laneInValue = lsb / numBits;
      const uint64_t idx = (prevShlBits / numBits) + laneInValue;
      if (idx < packedVecSize && toPack[idx] == nullptr) {
        toPack[idx] = v;
        bitOffset[idx] = laneInValue * numBits;
        continue;
      }
    }

    // An `or` of two operands with disjoint possibly-set bits merges two
    // partial packs; analyse both sides.
    if (match(v, m_Or(m_Value(lOp), m_Value(rOp)))) {
      const KnownBits kbL = computeKnownBits(lOp, DL);
      const KnownBits kbR = computeKnownBits(rOp, DL);
      const uint32_t nonZeroBitsL = ~(static_cast<uint32_t>(kbL.Zero.getZExtValue()));
      const uint32_t nonZeroBitsR = ~(static_cast<uint32_t>(kbR.Zero.getZExtValue()));
      if ((nonZeroBitsL & nonZeroBitsR) == 0) {
        args.push_back({lOp, prevShlBits});
        args.push_back({rOp, prevShlBits});
      }
      // NOTE(review): an `or` with overlapping operands is dropped from the
      // work list rather than rejected - confirm this is intended.
      continue;
    }

    // NOTE(review): once every lane is filled, an unrecognised value ends the
    // traversal instead of rejecting the pattern - confirm this is intended.
    if (std::all_of(toPack.begin(), toPack.end(), [](const Value *c) { return c != nullptr; })) {
      break;
    }
    // Unsupported pattern.
    return false;
  }
  if (std::any_of(toPack.begin(), toPack.end(), [](const Value *c) { return c == nullptr; })) {
    return false;
  }

  // Step 2: look through the sign + magnitude packing idiom and through simple
  // masking with `and`, so that the original source value is packed directly.
  for (uint32_t i = 0; i < packedVecSize; ++i) {
    Value *const elem = toPack[i];
    Value *lOp = nullptr;
    Value *rOp = nullptr;
    uint64_t lMask = 0;
    uint64_t rMask = 0;

    // (a & maskA) | (b & maskB) where one of the masks selects a single bit.
    if (match(elem, m_Or(m_And(m_Value(lOp), m_ConstantInt(lMask)), m_And(m_Value(rOp), m_ConstantInt(rMask)))) &&
        (countPopulation(rMask) == 1 || countPopulation(lMask) == 1)) {
      const bool signOnRight = countPopulation(rMask) == 1;
      Value *const signOp = signOnRight ? rOp : lOp;
      Value *const magnOp = signOnRight ? lOp : rOp;
      const uint64_t signMask = signOnRight ? rMask : lMask;
      const uint64_t magnMask = signOnRight ? lMask : rMask;

      // The shift sources are bound to dedicated variables: m_Value() stores
      // its operand as soon as the sub-pattern matches, even if the remaining
      // checks fail, so binding into `elem` would corrupt the fallback below.
      Value *shlSrc = nullptr;
      Value *shrSrc = nullptr;
      uint64_t shlBits = 0;
      uint64_t shrBits = 0;

      //   %b    = shl nsw i32 %a, 24
      //   %c    = and i32 %b, 2130706432
      //   %sign = and i32 %a, -2147483648
      //   %e    = or i32 %sign, %c
      if (match(magnOp, m_Shl(m_Value(shlSrc), m_ConstantInt(shlBits))) && shlSrc == signOp &&
          shlBits == (i * numBits) && (cSignMask << shlBits) == signMask && (cMagnMask << shlBits) == magnMask) {
        toPack[i] = shlSrc;
        bitOffset[i] = 0;
        continue;
      }
      //   %b    = and i32 %a, 127
      //   %c    = lshr i32 %a, 24
      //   %sign = and i32 %c, 128
      //   %e    = or i32 %sign, %b
      if (match(signOp, m_LShr(m_Value(shrSrc), m_ConstantInt(shrBits))) && shrSrc == magnOp &&
          shrBits == (32 - numBits) && cSignMask == signMask && cMagnMask == magnMask) {
        toPack[i] = shrSrc;
        bitOffset[i] = 0;
        continue;
      }
    }

    // x & mask, where the mask keeps only (some of) the low `numBits` bits.
    uint64_t andMask = 0;
    if (match(elem, m_And(m_Value(lOp), m_ConstantInt(andMask))) && (andMask & BITMASK(numBits)) == andMask) {
      toPack[i] = lOp;
      bitOffset[i] = 0;
      continue;
    }

    // The element is still the raw value found in step 1. If its bits are not
    // in the low lane, a trunc would extract zeros instead of the element, so
    // give up rather than emit a wrong packing sequence.
    if (bitOffset[i] != 0) {
      return false;
    }
  }

  // Create the packing sequence that is matched in the PatternMatch later.
  Type *elemTy = Builder->getIntNTy(numBits);
  Value *packed = PoisonValue::get(IGCLLVM::FixedVectorType::get(elemTy, packedVecSize));
  for (uint32_t i = 0; i < packedVecSize; ++i) {
    Value *elem = Builder->CreateTrunc(toPack[i], elemTy);
    packed = Builder->CreateInsertElement(packed, elem, Builder->getInt32(i));
  }
  packed = Builder->CreateBitCast(packed, inst->getType());
  inst->replaceAllUsesWith(packed);
  inst->eraseFromParent();
  return true;
}
1170+
10031171
FunctionPass *IGC::createGenIRLowerPass() { return new GenIRLowering(); }
10041172

10051173
// Register pass to igc-opt

IGC/Compiler/CISACodeGen/PatternMatchPass.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2842,22 +2842,27 @@ bool CodeGenPatternMatch::MatchPack4i8(Instruction &I) {
28422842
}
28432843
return false;
28442844
};
2845-
// Lambda matches clamp(x, 0, 127) pattern.
2845+
// Lambda matches clamp(x, MIN, MAX) pattern.
28462846
// If the pattern is found `x` is returned in the `clampedVal` reference.
2847-
auto MatchClamp0_127 = [&MatchMinMaxWithImm](Value *v, Value *&clampedVal) -> bool {
2847+
auto MatchClampWithImm = [&MatchMinMaxWithImm](Value *v, Value *&clampedVal, uint32_t minVal, uint32_t maxVal) -> bool {
28482848
bool matchMin = true;
28492849
bool matchMax = false;
28502850
Value *src[2];
28512851
// Match either of:
2852-
// v = min(max(x, 0), 127)
2853-
// v = max(min(x, 127), 0)
2854-
if ((MatchMinMaxWithImm(v, 127, matchMin, src[0]) && MatchMinMaxWithImm(src[0], 0, matchMax, src[1])) ||
2855-
(MatchMinMaxWithImm(v, 0, matchMax, src[0]) && MatchMinMaxWithImm(src[0], 127, matchMin, src[1]))) {
2852+
// v = min(max(x, MIN), MAX)
2853+
// v = max(min(x, MAX), MIN)
2854+
if ((MatchMinMaxWithImm(v, maxVal, matchMin, src[0]) && MatchMinMaxWithImm(src[0], minVal, matchMax, src[1])) ||
2855+
(MatchMinMaxWithImm(v, minVal, matchMax, src[0]) && MatchMinMaxWithImm(src[0], maxVal, matchMin, src[1]))) {
28562856
clampedVal = src[1];
28572857
return true;
28582858
}
28592859
return false;
28602860
};
2861+
// Lambda matches clamp(x, 0, 127) pattern.
2862+
// If the pattern is found `x` is returned in the `clampedVal` reference.
2863+
auto MatchClamp0_127 = [&MatchClampWithImm](Value *v, Value *&clampedVal) -> bool {
2864+
return MatchClampWithImm(v, clampedVal, 0, 127);
2865+
};
28612866

28622867
EOPCODE opcodes[4] = {};
28632868
Value *sources0[4] = {};
@@ -2902,16 +2907,23 @@ bool CodeGenPatternMatch::MatchPack4i8(Instruction &I) {
29022907
}
29032908
if (elemsFound == 4) {
29042909
// Match pattern 2
2905-
// Match clamping of values to 0..127 range, e.g.:
2906-
// %x1 = max i32 %x0, 0
2907-
// %x2 = min i32 %x1, 127
29082910
for (uint32_t i = 0; i < 4; ++i) {
29092911
Value *srcToSat;
2912+
// Match clamping of values to 0..127 range, e.g.:
2913+
// %x1 = max i32 %x0, 0
2914+
// %x2 = min i32 %x1, 127
29102915
if (MatchClamp0_127(sources0[i], srcToSat)) {
29112916
opcodes[i] = llvm_max;
29122917
sources0[i] = srcToSat;
29132918
sources1[i] = ConstantInt::get(srcToSat->getType(), 0);
29142919
isSat[i] = true;
2920+
// Match clamping of values to -128..127 range, e.g.:
2921+
// %x1 = max i32 %x0, -128
2922+
// %x2 = min i32 %x1, 127
2923+
} else if (MatchClampWithImm(sources0[i], srcToSat, -128, 127)) {
2924+
opcodes[i] = llvm_fptosi;
2925+
sources0[i] = srcToSat;
2926+
isSat[i] = true;
29152927
}
29162928
}
29172929
}

0 commit comments

Comments
 (0)