@@ -9194,6 +9194,9 @@ void EmitPass::EmitGenIntrinsicMessage(llvm::GenIntrinsicInst* inst)
91949194 case GenISAIntrinsic::GenISA_QuadPrefix:
91959195 emitQuadPrefix(cast<QuadPrefixIntrinsic>(inst));
91969196 break;
9197+ case GenISAIntrinsic::GenISA_WaveClusteredPrefix:
9198+ emitWaveClusteredPrefix(inst);
9199+ break;
91979200 case GenISAIntrinsic::GenISA_WaveAll:
91989201 emitWaveAll(inst);
91999202 break;
@@ -14431,8 +14434,18 @@ void EmitPass::emitReductionClusteredInterleave(const e_opcode op, const uint64_
1443114434void EmitPass::emitPreOrPostFixOp(
1443214435 e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc,
1443314436 CVariable* pSrc, CVariable* pSrcsArr[2], CVariable* Flag,
14434- bool isPrefix, bool isQuad)
14435- {
14437+ bool isPrefix, bool isQuad, int clusterSize)
14438+ {
14439+ // TODO: Arguments isQuad and clusterSize serve a similar purpose: both split the subgroup into
14440+ // smaller sets of lanes that are processed separately. isQuad could be considered clusterSize == 4,
14441+ // but there is a significant difference in implementation: when shifting the input by one lane
14442+ // to the right for an exclusive scan (isPrefix == true), isQuad inserts the identity value only
14443+ // into the first lane of the subgroup, whereas clusterSize == 8/16 inserts the identity value into
14444+ // the first lane of each cluster.
14445+ //
14446+ // isQuad/clusterSize could be replaced with one argument, but the code must be refactored
14447+ // to not break QuadPrefix intrinsic.
14448+
1443614449 const bool isInt64Mul = ScanReduceIsInt64Mul(op, type);
1443714450 const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type);
1443814451
@@ -14442,12 +14455,12 @@ void EmitPass::emitPreOrPostFixOp(
1444214455 emitPreOrPostFixOpScalar(
1444314456 op, identityValue, type, negateSrc,
1444414457 pSrc, pSrcsArr, Flag,
14445- isPrefix);
14458+ isPrefix, clusterSize );
1444614459 return;
1444714460 }
1444814461
14449- bool isSimd32 = m_currShader->m_numberInstance == 2;
14450- int counter = isSimd32 ? 2 : 1;
14462+ bool isSimd32AsTwoInstances = m_currShader->m_numberInstance == 2;
14463+ int counter = isSimd32AsTwoInstances ? 2 : 1;
1445114464
1445214465 CVariable* maskedSrc[2] = { 0 };
1445314466 for (int i = 0; i < counter; ++i)
@@ -14466,7 +14479,9 @@ void EmitPass::emitPreOrPostFixOp(
1446614479 // Copy identity
1446714480 m_encoder->SetSimdSize(SIMDMode::SIMD1);
1446814481 m_encoder->SetNoMask();
14469- if (i == 0)
14482+ // Before shift, insert identity value to the first lane
14483+ // in subgroup (or cluster).
14484+ if (i == 0 || clusterSize > 0)
1447014485 {
1447114486 CVariable* pIdentityValue = m_currShader->ImmToVariable(identityValue, type);
1447214487 m_encoder->Copy(pSrcCopy, pIdentityValue);
@@ -14496,7 +14511,25 @@ void EmitPass::emitPreOrPostFixOp(
1449614511 }
1449714512 offset += simdsize;
1449814513 }
14514+
14515+ // After shifting the input by one lane, in each cluster that starts in
14516+ // the middle of GRF, set the first lane to the identity value.
14517+ if (clusterSize > 0)
14518+ {
14519+ m_encoder->SetSimdSize(SIMDMode::SIMD1);
14520+ m_encoder->SetNoMask();
14521+ CVariable* pIdentityValue = m_currShader->ImmToVariable(identityValue, type);
14522+
14523+ for (int i = clusterSize; i < pSrcCopy->GetNumberElement(); i += clusterSize)
14524+ {
14525+ m_encoder->SetDstSubReg(i);
14526+ m_encoder->Copy(pSrcCopy, pIdentityValue);
14527+ }
14528+
14529+ m_encoder->Push();
14530+ }
1449914531 }
14532+
1450014533 pSrcsArr[i] = pSrcCopy;
1450114534 }
1450214535
@@ -14593,7 +14626,7 @@ void EmitPass::emitPreOrPostFixOp(
1459314626 }
1459414627 };
1459514628
14596- if (m_currShader->m_dispatchSize == SIMDMode::SIMD32 && !isSimd32 )
14629+ if (m_currShader->m_dispatchSize == SIMDMode::SIMD32 && !isSimd32AsTwoInstances )
1459714630 {
1459814631 // handling the single SIMD32 size case in PVC
1459914632 // the logic is mostly similar to the legacy code sequence below, except that
@@ -14647,6 +14680,12 @@ void EmitPass::emitPreOrPostFixOp(
1464714680 (loop_counter * 8 + 4) /*dst subreg*/, 1 /*dst region*/);
1464814681 }
1464914682
14683+ if (clusterSize == 8)
14684+ {
14685+ // With SIMD8 clusters, stop at SIMD8 prefix.
14686+ return;
14687+ }
14688+
1465014689 // Merge: 2 SIMD8's to get 2 SIMD16 prefix sequence
1465114690 for (uint loop_counter = 0; loop_counter < 2; ++loop_counter)
1465214691 {
@@ -14659,6 +14698,12 @@ void EmitPass::emitPreOrPostFixOp(
1465914698 loop_counter * 16 + 8 /*dst subreg*/, 1 /*dst region*/);
1466014699 }
1466114700
14701+ if (clusterSize == 16)
14702+ {
14703+ // With SIMD16 clusters, stop at SIMD16 prefix.
14704+ return;
14705+ }
14706+
1466214707 // final merge to get 1 SIMD32 prefix sequence and voila!
1466314708 {
1466414709 const uint src0Region[3] = { 0, 1, 0 };
@@ -14783,7 +14828,13 @@ void EmitPass::emitPreOrPostFixOp(
1478314828 (loop_counter * 8 + 4) /*dst subreg*/, 1 /*dst region*/);
1478414829 }
1478514830
14786- if (m_currShader->m_SIMDSize == SIMDMode::SIMD16 || isSimd32)
14831+ if (clusterSize == 8)
14832+ {
14833+ // Stop ALU ops at SIMD8 lanes.
14834+ continue;
14835+ }
14836+
14837+ if (m_currShader->m_SIMDSize == SIMDMode::SIMD16 || isSimd32AsTwoInstances)
1478714838 {
1478814839 // Add the last element of the 1st GRF to all the elements of the 2nd GRF
1478914840 const uint src0Region[3] = { 0, 1, 0 };
@@ -14796,7 +14847,8 @@ void EmitPass::emitPreOrPostFixOp(
1479614847 }
1479714848 }
1479814849
14799- if (isSimd32 && !isQuad)
14850+ bool hasClusters = isQuad || clusterSize > 0;
14851+ if (isSimd32AsTwoInstances && !hasClusters)
1480014852 {
1480114853 // For SIMD32 we need to write the last element of the prev element to the next 16 elements
1480214854 const uint src0Region[3] = { 0, 1, 0 };
@@ -14820,13 +14872,14 @@ void EmitPass::emitPreOrPostFixOpScalar(
1482014872 CVariable* src,
1482114873 CVariable* result[2],
1482214874 CVariable* Flag,
14823- bool isPrefix)
14875+ bool isPrefix,
14876+ int clusterSize)
1482414877{
1482514878 const bool isInt64Mul = ScanReduceIsInt64Mul(op, type);
1482614879 const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type);
1482714880
14828- bool isSimd32 = m_currShader->m_numberInstance == 2;
14829- int counter = isSimd32 ? 2 : 1;
14881+ bool isSimd32AsTwoInstances = m_currShader->m_numberInstance == 2;
14882+ int counter = isSimd32AsTwoInstances ? 2 : 1;
1483014883 CVariable* pSrcCopy[2] = {};
1483114884 for (int i = 0; i < counter; ++i)
1483214885 {
@@ -14849,7 +14902,7 @@ void EmitPass::emitPreOrPostFixOpScalar(
1484914902 if (isPrefix)
1485014903 {
1485114904 // For case where we need the prefix shift the source by 1 lane.
14852- if (i == 0)
14905+ if (i == 0 || clusterSize == 8 || clusterSize == 16 )
1485314906 {
1485414907 // (W) mov (1) result[0] identity
1485514908 CVariable* pIdentityValue = m_currShader->ImmToVariable(identityValue, type);
@@ -14884,6 +14937,23 @@ void EmitPass::emitPreOrPostFixOpScalar(
1488414937
1488514938 for (int dstIdx = 1; dstIdx < numLanes(m_currShader->m_SIMDSize); ++dstIdx, ++srcIdx)
1488614939 {
14940+ // Scan is done one by one. With clusters, start each cluster with
14941+ // initial value.
14942+ if ((clusterSize == 8 || clusterSize == 16) && dstIdx % clusterSize == 0)
14943+ {
14944+ // For case where we need the prefix, start cluster with
14945+ // identity value.
14946+ if (isPrefix)
14947+ {
14948+ m_encoder->SetSimdSize(SIMDMode::SIMD1);
14949+ m_encoder->SetNoMask();
14950+ m_encoder->SetDstSubReg(dstIdx);
14951+ CVariable* pIdentityValue = m_currShader->ImmToVariable(identityValue, type);
14952+ m_encoder->Copy(result[i], pIdentityValue);
14953+ continue;
14954+ }
14955+ }
14956+
1488714957 // do the scan one by one
1488814958 // (W) op (1) result[dstIdx] srcCopy[srcIdx] result[dstIdx-1]
1488914959 if (!int64EmulationNeeded)
@@ -14924,7 +14994,7 @@ void EmitPass::emitPreOrPostFixOpScalar(
1492414994 m_encoder->SetSecondHalf(false);
1492514995 }
1492614996
14927- if (isSimd32 )
14997+ if (isSimd32AsTwoInstances && !clusterSize )
1492814998 {
1492914999 const SIMDMode simd = SIMDMode::SIMD16;
1493015000
@@ -22157,6 +22227,68 @@ void EmitPass::emitScan(
2215722227 m_encoder->Push();
2215822228}
2215922229
// Emits code for the GenISA_WaveClusteredPrefix intrinsic: a prefix scan that is
// computed separately inside each cluster of lanes of the subgroup.
//
// Intrinsic operands as read by this function:
//   operand 0 - the source value to scan
//   operand 1 - the IGC::WaveOps operation (only SUM and FSUM pass the assertion below)
//   operand 2 - the cluster size; must be an llvm::ConstantInt equal to 8, 16 or 32
//   operand 3 - helper-lanes mode; the value 2 is treated as "disable helper lanes"
//
// The result is copied into m_destination. Unsupported inputs are only caught by
// assertions; this function has no other fallback path for them.
22230+ void EmitPass::emitWaveClusteredPrefix(GenIntrinsicInst* I)
22231+ {
22232+ auto helperLanes = int_cast<int>(cast<ConstantInt>(I->getArgOperand(3))->getSExtValue());
22233+ bool disableHelperLanes = (helperLanes == 2);
22234+
// The cluster size is read at compile time, so it has to be a constant operand.
22235+ IGC_ASSERT_MESSAGE(isa<llvm::ConstantInt>(I->getOperand(2)), "Unsupported: cluster size must be constant");
22236+ const unsigned int clusterSize = int_cast<uint32_t>(cast<llvm::ConstantInt>(I->getOperand(2))->getZExtValue());
22237+
// A cluster cannot span more lanes than the shader's dispatch size provides.
22238+ IGC_ASSERT_MESSAGE(clusterSize <= numLanes(m_currShader->m_dispatchSize), "Cluster size must be smaller or equal to SIMD");
22239+ IGC_ASSERT_MESSAGE(clusterSize == 8 || clusterSize == 16 || clusterSize == 32, "Cluster size must be 8/16/32");
22240+
22241+ IGC::WaveOps Op = static_cast<IGC::WaveOps>(I->getImm64Operand(1));
22242+ IGC_ASSERT_MESSAGE(Op == IGC::WaveOps::SUM || Op == IGC::WaveOps::FSUM, "Unsupported op type");
22243+
// NOTE(review): ForceDMask()/ResetVMask() are defined elsewhere; presumably they switch
// the execution mask so helper lanes do not take part in the scan and then restore it
// at the end of this function - confirm against their definitions.
22244+ if (disableHelperLanes)
22245+ {
22246+ ForceDMask();
22247+ }
22248+
22249+ Value* Src = I->getOperand(0);
22250+
22251+ if (clusterSize == numLanes(m_currShader->m_dispatchSize))
22252+ {
22253+ // If cluster size is equal to SIMD size, just run normal scan.
// NOTE(review): emitScan's parameter list is not visible here; the literals are assumed
// to select a non-inclusive (exclusive) scan, no mask and non-quad mode - TODO confirm.
22254+ emitScan(Src, Op, false, nullptr, false);
22255+ }
22256+ else
22257+ {
22258+ // Run scan with clusters.
22259+
// Translate the wave op and the source's LLVM type into the opcode, vISA type and
// identity value used by the scan (presumably out-parameters of GetReductionOp).
22260+ VISA_Type type;
22261+ e_opcode opCode;
22262+ uint64_t identity = 0;
22263+ GetReductionOp(Op, Src->getType(), identity, opCode, type);
22264+
// The clustered path rejects types of size 8 (64b per the message) for which
// ScanReduceIsInt64EmulationNeeded() reports that emulation is required.
22265+ IGC_ASSERT_MESSAGE((CEncoder::GetCISADataTypeSize(type) == 8 && ScanReduceIsInt64EmulationNeeded(opCode, type)) == false,
22266+ "Unsupported: 64b data type");
22267+
22268+ CVariable* src = GetSymbol(Src);
// Filled in by emitPreOrPostFixOp: one entry per shader instance.
22269+ CVariable* dst[2] = { nullptr, nullptr };
22270+
// Arguments after `type`, in order: negateSrc = false, source, output array,
// Flag = nullptr, isPrefix = true, isQuad = false, then the cluster size.
22271+ emitPreOrPostFixOp(
22272+ opCode, identity, type,
22273+ false, src, dst, nullptr,
22274+ true, false, clusterSize);
22275+
22276+ m_encoder->Copy(m_destination, dst[0]);
// When the shader runs as two instances, the second instance's result is copied
// with the encoder switched to the second half.
22277+ if (m_currShader->m_numberInstance == 2)
22278+ {
22279+ m_encoder->SetSecondHalf(true);
22280+ m_encoder->Copy(m_destination, dst[1]);
22281+ m_encoder->SetSecondHalf(false);
22282+ }
22283+ m_encoder->Push();
22284+ }
22285+
22286+ if (disableHelperLanes)
22287+ {
22288+ ResetVMask();
22289+ }
22290+ }
22291+
2216022292void EmitPass::emitWaveAll(llvm::GenIntrinsicInst* inst)
2216122293{
2216222294 bool disableHelperLanes = int_cast<int>(cast<ConstantInt>(inst->getArgOperand(2))->getSExtValue()) == 2;
0 commit comments