Skip to content

Commit bad3bca

Browse files
pkwasnie-inteligcbot
authored andcommitted
simplify insertelement for large vectors
LLVM's instcombine pass (method InstCombinerImpl::SimplifyDemandedVectorElts) is able to set input on first insertelement to poison if all indices in vector are inserted (all indices are overwritten). This optimization has a hardcoded limit on vector size in LLVM 15 and 16, which is too short for compute workloads. Manually optimize such cases in IGC.
1 parent 78e8d97 commit bad3bca

File tree

3 files changed

+1101
-0
lines changed

3 files changed

+1101
-0
lines changed

IGC/Compiler/CustomSafeOptPass.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2160,6 +2160,53 @@ void CustomSafeOptPass::dp4WithIdentityMatrix(ExtractElementInst &I) {
21602160
addI[2]->replaceAllUsesWith(builder.CreateNeg(sel3));
21612161
}
21622162

2163+
// Check if a new vector is constructed, and if so, replace operand in first insert to undef. This optimization is
2164+
// similar to LLVM instcombine pass (InstCombinerImpl::SimplifyDemandedVectorElts). LLVM optimization is limited to
2165+
// short vectors, where IGC must support larger vectors (16, 32 elements).
2166+
void CustomSafeOptPass::visitInsertElementInst(llvm::InsertElementInst &I) {
2167+
using namespace llvm::PatternMatch;
2168+
2169+
auto *VType = dyn_cast<IGCLLVM::FixedVectorType>(I.getType());
2170+
if (!VType)
2171+
return;
2172+
2173+
auto VWidth = VType->getNumElements();
2174+
auto ElSize = VType->getElementType()->getPrimitiveSizeInBits();
2175+
2176+
// Optimize only vectors typical for DPAS src.
2177+
bool ValidVector = iSTD::IsPowerOfTwo(VWidth) &&
2178+
((ElSize == 8 && VWidth <= 32) || (ElSize == 16 && VWidth <= 16) || (ElSize == 32 && VWidth <= 8));
2179+
if (!ValidVector)
2180+
return;
2181+
2182+
// Pattern is matched bottom-up, starting from last IE in chain. Instructions are visited top-down, exit early if this
2183+
// is not the last IE in chain.
2184+
if (I.hasOneUse() && isa<InsertElementInst>(I.use_begin()->getUser()))
2185+
return;
2186+
2187+
Instruction *CurrentInst = &I, *NextInst = nullptr;
2188+
uint64_t Index = 0;
2189+
APInt VisitedMask(APInt::getZero(VWidth));
2190+
2191+
while (true) {
2192+
2193+
if (!match(CurrentInst, m_InsertElt(m_Instruction(NextInst), m_Value(), m_ConstantInt(Index))))
2194+
return;
2195+
2196+
VisitedMask.setBit(Index);
2197+
2198+
if (NextInst->hasOneUse() && NextInst->getOpcode() == Instruction::InsertElement) {
2199+
CurrentInst = NextInst;
2200+
} else {
2201+
if (VisitedMask.isAllOnesValue()) {
2202+
// All elements are inserted, so input vector is fully overwritten.
2203+
CurrentInst->setOperand(0, PoisonValue::get(VType));
2204+
}
2205+
return;
2206+
}
2207+
}
2208+
}
2209+
21632210
void CustomSafeOptPass::visitExtractElementInst(ExtractElementInst &I) {
21642211
// convert:
21652212
// %1 = lshr i32 %0, 16,

IGC/Compiler/CustomSafeOptPass.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class CustomSafeOptPass : public llvm::FunctionPass, public llvm::InstVisitor<Cu
5858
void visitMulH(llvm::CallInst *inst, bool isSigned);
5959
void visitFPToUIInst(llvm::FPToUIInst &FPUII);
6060
void visitFPTruncInst(llvm::FPTruncInst &I);
61+
void visitInsertElementInst(llvm::InsertElementInst &I);
6162
void visitExtractElementInst(llvm::ExtractElementInst &I);
6263
void visitLdptr(llvm::SamplerLoadIntrinsic *inst);
6364
void visitLdRawVec(llvm::CallInst *inst);

0 commit comments

Comments
 (0)