Skip to content

Commit 0b8394d

Browse files
esukhovigcbot
authored and committed
Vectorizer Update fmul instruction vectorized
The vectorizer can now support vector emission of fmul instructions. Implemented for the Triton flash attention kernel.
1 parent 7bbc6f6 commit 0b8394d

File tree

8 files changed

+954
-227
lines changed

8 files changed

+954
-227
lines changed

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4330,6 +4330,7 @@ void EmitPass::EmitGenericPointersCmp(llvm::Instruction* inst,
43304330

43314331
void EmitPass::BinaryUnary(llvm::Instruction* inst, const SSource source[2], const DstModifier& modifier)
43324332
{
4333+
43334334
switch (inst->getOpcode())
43344335
{
43354336
case Instruction::FCmp:
@@ -4361,6 +4362,9 @@ void EmitPass::BinaryUnary(llvm::Instruction* inst, const SSource source[2], con
43614362
case Instruction::Mul:
43624363
Mul(source, modifier);
43634364
break;
4365+
case Instruction::FMul:
4366+
Mul(source, modifier);
4367+
break;
43644368
case Instruction::Call:
43654369
EmitAluIntrinsic(cast<CallInst>(inst), source, modifier);
43664370
break;
@@ -4572,6 +4576,15 @@ void EmitPass::Mul64(CVariable* dst, CVariable* src[2], SIMDMode simdMode, bool
45724576
m_encoder->Push();
45734577
}
45744578

4579+
// Returns the number of elements in the fixed-width vector type produced by
// instruction I, or 0 when I's result type is not a fixed vector type.
static unsigned int getVectorSize(Instruction *I) {
    if (auto *FixedVecTy = llvm::dyn_cast<IGCLLVM::FixedVectorType>(I->getType())) {
        return FixedVecTy->getNumElements();
    }
    // Not a fixed vector: report zero elements.
    return 0;
}
4587+
45754588
void EmitPass::Mul(const SSource sources[2], const DstModifier& modifier)
45764589
{
45774590
CVariable* src[2];
@@ -4580,6 +4593,28 @@ void EmitPass::Mul(const SSource sources[2], const DstModifier& modifier)
45804593
src[i] = GetSrcVariable(sources[i]);
45814594
}
45824595

4596+
if (IGC_IS_FLAG_ENABLED(EnableVectorEmitter) && sources[0].value->getType()->isVectorTy() && sources[1].value->getType()->isVectorTy()) {
4597+
4598+
unsigned int VectorSize = 0;
4599+
if (llvm::isa<Instruction>(sources[0].value))
4600+
VectorSize = getVectorSize(llvm::cast<Instruction>(sources[0].value));
4601+
4602+
for (unsigned int i = 0; i < VectorSize; ++i) {
4603+
SetSourceModifiers(0, sources[0]);
4604+
SetSourceModifiers(1, sources[1]);
4605+
4606+
if (src[0]->IsUniform()) { m_encoder->SetSrcSubReg(0, i); }
4607+
else m_encoder->SetSrcSubVar(0, i);
4608+
if (src[1]->IsUniform()) { m_encoder->SetSrcSubReg(1, i); }
4609+
else m_encoder->SetSrcSubVar(1, i);
4610+
4611+
m_encoder->SetDstSubVar(i);
4612+
m_encoder->Mul(m_destination, src[0], src[1]);
4613+
m_encoder->Push();
4614+
}
4615+
return;
4616+
}
4617+
45834618
// Only i64 muls need special handling, otherwise go back to standard flow
45844619
VISA_Type srcType = src[0]->GetType();
45854620
if (srcType != ISA_TYPE_Q && srcType != ISA_TYPE_UQ)

0 commit comments

Comments
 (0)