@@ -7435,14 +7435,17 @@ void EmitPass::EmitGenIntrinsicMessage(llvm::GenIntrinsicInst* inst)
74357435 case GenISAIntrinsic::GenISA_WaveBallot:
74367436 emitWaveBallot(inst);
74377437 break;
7438+ case GenISAIntrinsic::GenISA_WaveInverseBallot:
7439+ emitWaveInverseBallot(inst);
7440+ break;
74387441 case GenISAIntrinsic::GenISA_WaveShuffleIndex:
74397442 emitSimdShuffle(inst);
74407443 break;
74417444 case GenISAIntrinsic::GenISA_WavePrefix:
7442- emitWavePrefix(inst);
7445+ emitWavePrefix(cast<WavePrefixIntrinsic>( inst) );
74437446 break;
74447447 case GenISAIntrinsic::GenISA_QuadPrefix:
7445- emitWavePrefix( inst, true );
7448+ emitQuadPrefix(cast<QuadPrefixIntrinsic>( inst) );
74467449 break;
74477450 case GenISAIntrinsic::GenISA_WaveAll:
74487451 emitWaveAll(inst);
@@ -9981,16 +9984,21 @@ void EmitPass::emitReductionAll(
99819984 m_encoder->Push();
99829985}
99839986
9984- void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc, CVariable* pSrc, CVariable* pSrcsArr[2], bool isPrefix, bool isQuad)
9987+ void EmitPass::emitPreOrPostFixOp(
9988+ e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc,
9989+ CVariable* pSrc, CVariable* pSrcsArr[2], CVariable *Flag,
9990+ bool isPrefix, bool isQuad)
99859991{
9986- // This is to handle cases when not all lanes are enabled. In that case we fill the lanes with 0.
9987-
99889992 if (m_currShader->m_Platform->doScalar64bScan() && CEncoder::GetCISADataTypeSize(type) == 8 && !isQuad)
99899993 {
9990- emitPreOrPostFixOpScalar(op, identityValue, type, negateSrc, pSrc, pSrcsArr, isPrefix);
9994+ emitPreOrPostFixOpScalar(
9995+ op, identityValue, type, negateSrc,
9996+ pSrc, pSrcsArr, Flag,
9997+ isPrefix);
99919998 return;
99929999 }
999310000
10001+ // This is to handle cases when not all lanes are enabled. In that case we fill the disabled lanes with the identity value of the operation.
999410002 bool isSimd32 = (m_currShader->m_dispatchSize == SIMDMode::SIMD32);
999510003 int counter = 1;
999610004 if (isSimd32)
@@ -10004,18 +10012,20 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
1000410012 IGC::EALIGN_GRF,
1000510013 false);
1000610014
10007- // Set the GRF to 0 with no mask. This will set all the registers to 0
10015+ // Set the GRF to <identity> with no mask. This will set all the registers to <identity>
1000810016 CVariable* pIdentityValue = m_currShader->ImmToVariable(identityValue, type);
1000910017 m_encoder->SetNoMask();
1001010018 m_encoder->Copy(pSrcCopy, pIdentityValue);
1001110019 m_encoder->Push();
1001210020
10013- // Now copy the src with a mask so the disabled lanes still keep their 0
10021+ // Now copy the src with a mask so the disabled lanes still keep their <identity>
1001410022 if (negateSrc)
1001510023 {
1001610024 m_encoder->SetSrcModifier(0, EMOD_NEG);
1001710025 }
1001810026 m_encoder->SetSecondHalf(i == 1);
10027+ if (Flag)
10028+ m_encoder->SetPredicate(Flag);
1001910029 m_encoder->Copy(pSrcCopy, pSrc);
1002010030 m_encoder->Push();
1002110031
@@ -10063,8 +10073,9 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
1006310073 {
1006410074 /*
1006510075 Copy the adjacent elements.
10066- for example: r10 be the register
10067- ____ ____ ____ ____
10076+ for example: let r10 be the register
10077+ Assume we are performing addition for this example
10078+ ____ ____ ____ ____
1006810079 __|____|____|____|____|____|____|____|_
1006910080 | 7 | 6 | 5 | 4 | 9 | 5 | 3 | 2 |
1007010081 ---------------------------------------
@@ -10095,10 +10106,10 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
1009510106 }
1009610107
1009710108 /*
10098- ____ ____
10099- _______|____|________________|____|______ ___________________________________________
10109+ ____ ____
10110+ _______|____|________________|____|______ ___________________________________________
1010010111 | 13 | 6 | 9 | 4 | 14 | 5 | 5 | 2 | ==> | 13 | 15 | 9 | 4 | 14 | 10 | 5 | 2 |
10101- ----------------------------------------- -------------------------------------------
10112+ ----------------------------------------- -------------------------------------------
1010210113 */
1010310114 // Now we have a weird copy happening. This will be done by SIMD 2 instructions.
1010410115
@@ -10127,7 +10138,7 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
1012710138 }
1012810139
1012910140 /*
10130- ___________ ___________
10141+ ___________ ___________
1013110142 __|___________|_________|___________|______ ___________________________________________
1013210143 | 13 | 15 | 9 | 4 | 14 | 10 | 5 | 2 | ==> | 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 |
1013310144 ------------------------------------------- -------------------------------------------
@@ -10164,21 +10175,21 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
1016410175 }
1016510176
1016610177 /*
10167- ____
10178+ ____
1016810179 __________________|____|_________________ ____________________________________________
1016910180 | 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 | ==> | 22 | 15 | 9 | 23 | 19 | 10 | 5 | 2 |
1017010181 ----------------------------------------- --------------------------------------------
10171- _________
10182+ _________
1017210183 _____________|_________|_________________ _____________________________________________
1017310184 | 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 | ==> | 22 | 15 | 28 | 23 | 19 | 10 | 5 | 2 |
1017410185 ----------------------------------------- ---------------------------------------------
1017510186
10176- ______________
10187+ ______________
1017710188 ________|______________|_________________ _____________________________________________
1017810189 | 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 | ==> | 22 | 34 | 28 | 23 | 19 | 10 | 5 | 2 |
1017910190 ----------------------------------------- ---------------------------------------------
1018010191
10181- ____________________
10192+ ____________________
1018210193 __|____________________|_________________ _____________________________________________
1018310194 | 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 | ==> | 41 | 34 | 28 | 23 | 19 | 10 | 5 | 2 |
1018410195 ----------------------------------------- ---------------------------------------------
@@ -10239,6 +10250,7 @@ void EmitPass::emitPreOrPostFixOpScalar(
1023910250 bool negateSrc,
1024010251 CVariable* src,
1024110252 CVariable* result[2],
10253+ CVariable* Flag,
1024210254 bool isPrefix)
1024310255{
1024410256 // This is to handle cases when not all lanes are enabled. In that case we fill the disabled lanes with the identity value of the operation.
@@ -10259,19 +10271,21 @@ void EmitPass::emitPreOrPostFixOpScalar(
1025910271 IGC::EALIGN_GRF,
1026010272 false);
1026110273
10262- // Set the GRF to 0 with no mask. This will set all the registers to 0
10274+ // Set the GRF to <identity> with no mask. This will set all the registers to <identity>
1026310275 CVariable* pIdentityValue = m_currShader->ImmToVariable(identityValue, type);
1026410276 m_encoder->SetSecondHalf(i == 1);
1026510277 m_encoder->SetNoMask();
1026610278 m_encoder->Copy(pSrcCopy[i], pIdentityValue);
1026710279 m_encoder->Push();
1026810280
10269- // Now copy the src with a mask so the disabled lanes still keep their 0
10281+ // Now copy the src with a mask so the disabled lanes still keep their <identity>
1027010282 if (negateSrc)
1027110283 {
1027210284 m_encoder->SetSrcModifier(0, EMOD_NEG);
1027310285 }
1027410286 m_encoder->SetSecondHalf(i == 1);
10287+ if (Flag)
10288+ m_encoder->SetPredicate(Flag);
1027510289 m_encoder->Copy(pSrcCopy[i], src);
1027610290 m_encoder->Push();
1027710291
@@ -14326,6 +14340,33 @@ void EmitPass::emitWaveBallot(llvm::GenIntrinsicInst* inst)
1432614340 }
1432714341}
1432814342
// Emits code for the GenISA_WaveInverseBallot intrinsic: the mask in
// operand 0 of 'inst' is converted into the value written to m_destination.
//
// inst  The intrinsic call being lowered; only operand 0 (the mask) is read.
//
// Two paths:
//  * Uniform mask: the mask is handed directly to SetP(). Nothing is emitted
//    while the encoder is in its second half.
//  * Non-uniform mask: computes (mask & (1 << id)) != 0 through a temporary
//    UD variable and a compare against zero.
14343+ void EmitPass::emitWaveInverseBallot(llvm::GenIntrinsicInst* inst)
14344+ {
14345+ CVariable *Mask = GetSymbol(inst->getOperand(0));
14346+
14347+ if (Mask->IsUniform())
14348+ {
// NOTE(review): presumably the SetP emitted for the first half already
// covers every lane, so the second half has nothing left to do - confirm.
14349+ if (m_encoder->IsSecondHalf())
14350+ return;
14351+
14352+ m_encoder->SetP(m_destination, Mask);
14353+ return;
14354+ }
14355+
14356+ // The uniform case should by far be the most common. Otherwise,
14357+ // fall back and compute:
14358+ //
14359+ // (val & (1 << id)) != 0
// Temp holds one UD element per lane of the shader's SIMD size.
14360+ CVariable *Temp = m_currShader->GetNewVariable(
14361+ numLanes(m_currShader->m_SIMDSize), ISA_TYPE_UD, EALIGN_GRF);
14362+
// NOTE(review): GetSimdOffsetBase() is assumed to fill Temp with the
// per-lane index ('id' in the formula above) - verify against its definition.
14363+ m_currShader->GetSimdOffsetBase(Temp);
// Temp = 1 << Temp
14364+ m_encoder->Shl(Temp, m_currShader->ImmToVariable(1, ISA_TYPE_UD), Temp);
// Temp = Mask & Temp
14365+ m_encoder->And(Temp, Mask, Temp);
// m_destination = (Temp != 0)
14366+ m_encoder->Cmp(EPREDICATE_NE,
14367+ m_destination, Temp, m_currShader->ImmToVariable(0, ISA_TYPE_UD));
14368+ }
14369+
1432914370static void GetReductionOp(WaveOps op, Type* opndTy, uint64_t& identity, e_opcode& opcode, VISA_Type& type)
1433014371{
1433114372 auto getISAType = [](Type* ty, bool isSigned = true)
@@ -14468,17 +14509,49 @@ static void GetReductionOp(WaveOps op, Type* opndTy, uint64_t& identity, e_opcod
1446814509 }
1446914510}
1447014511
14471- void EmitPass::emitWavePrefix(llvm::GenIntrinsicInst* inst, bool isQuad)
// Lowers a wave prefix (scan) intrinsic by forwarding its source, operation
// kind, inclusive/exclusive selector and mask to emitScan() with isQuad set
// to false.
//
// I  The wave prefix intrinsic; its getSrc(), getOpKind(), isInclusiveScan()
//    and getMask() accessors are read.
14512+ void EmitPass::emitWavePrefix(WavePrefixIntrinsic* I)
14513+ {
14514+ Value *Mask = I->getMask();
// Only a compile-time constant mask can be recognized as "all lanes";
// any other mask value is passed through to emitScan() unchanged.
14515+ if (auto *CI = dyn_cast<ConstantInt>(Mask))
14516+ {
14517+ // If the mask is all set, then we just pass a null
14518+ // mask to emitScan() indicating we don't want to
14519+ // emit any predication.
14520+ if (CI->isAllOnesValue())
14521+ Mask = nullptr;
14522+ }
14523+ emitScan(
14524+ I->getSrc(), I->getOpKind(), I->isInclusiveScan(), Mask, false);
14525+ }
14526+
// Lowers a quad prefix (scan) intrinsic by forwarding its source, operation
// kind and inclusive/exclusive selector to emitScan() with isQuad set to
// true. A null mask is always passed, so no mask-based predication is
// requested for the quad variant.
//
// I  The quad prefix intrinsic; its getSrc(), getOpKind() and
//    isInclusiveScan() accessors are read.
14527+ void EmitPass::emitQuadPrefix(QuadPrefixIntrinsic* I)
14528+ {
14529+ emitScan(
14530+ I->getSrc(), I->getOpKind(), I->isInclusiveScan(), nullptr, true);
14531+ }
14532+
14533+ void EmitPass::emitScan(
14534+ Value *Src, IGC::WaveOps Op,
14535+ bool isInclusiveScan, Value *Mask, bool isQuad)
1447214536{
14473- WaveOps op = static_cast<WaveOps>(cast<llvm::ConstantInt>(inst->getOperand(1))->getZExtValue());
14474- bool isInclusiveScan = cast<llvm::ConstantInt>(inst->getOperand(2))->getZExtValue() != 0;
1447514537 VISA_Type type;
1447614538 e_opcode opCode;
1447714539 uint64_t identity = 0;
14478- GetReductionOp(op, inst->getOperand(0) ->getType(), identity, opCode, type);
14479- CVariable* src = GetSymbol(inst->getOperand(0) );
14540+ GetReductionOp(Op, Src ->getType(), identity, opCode, type);
14541+ CVariable* src = GetSymbol(Src );
1448014542 CVariable *dst[2] = { nullptr, nullptr };
14481- emitPreOrPostFixOp(opCode, identity, type, false, src, dst, !isInclusiveScan, isQuad);
14543+ CVariable *Flag = Mask ? GetSymbol(Mask) : nullptr;
14544+
14545+ emitPreOrPostFixOp(
14546+ opCode, identity, type,
14547+ false, src, dst, Flag,
14548+ !isInclusiveScan, isQuad);
14549+
14550+ // Now that we've computed the result in temporary registers,
14551+ // make sure we only write the results to lanes participating in the
14552+ // scan as specified by 'mask'.
14553+ if (Flag)
14554+ m_encoder->SetPredicate(Flag);
1448214555 m_encoder->Copy(m_destination, dst[0]);
1448314556 if (m_currShader->m_dispatchSize == SIMDMode::SIMD32)
1448414557 {
0 commit comments