Skip to content

Commit e758184

Browse files
pkwasnie-inteligcbot
authored andcommitted
New GenISA intrinsic: WaveClusteredInterleave
Adds new GenISA intrinsic WaveClusteredInterleave that combines two wave reductions: WaveClustered and WaveInterleave. Subgroup is split into clusters (like WaveClustered), and then each cluster does interleaved reduction (WaveInterleave). Change includes a pattern match for reduction implemented with subgroup shuffles.
1 parent 427a492 commit e758184

File tree

14 files changed

+771
-97
lines changed

14 files changed

+771
-97
lines changed

IGC/Compiler/CISACodeGen/CheckInstrTypes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ void CheckInstrTypes::visitCallInst(CallInst& C)
343343
case GenISAIntrinsic::GenISA_WavePrefix:
344344
case GenISAIntrinsic::GenISA_WaveClustered:
345345
case GenISAIntrinsic::GenISA_WaveInterleave:
346+
case GenISAIntrinsic::GenISA_WaveClusteredInterleave:
346347
case GenISAIntrinsic::GenISA_QuadPrefix:
347348
case GenISAIntrinsic::GenISA_simdShuffleDown:
348349
case GenISAIntrinsic::GenISA_simdShuffleXor:

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 167 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8980,6 +8980,9 @@ void EmitPass::EmitGenIntrinsicMessage(llvm::GenIntrinsicInst* inst)
89808980
case GenISAIntrinsic::GenISA_WaveClustered:
89818981
emitWaveClustered(inst);
89828982
break;
8983+
case GenISAIntrinsic::GenISA_WaveClusteredInterleave:
8984+
emitWaveClusteredInterleave(inst);
8985+
break;
89838986
case GenISAIntrinsic::GenISA_dp4a_ss:
89848987
case GenISAIntrinsic::GenISA_dp4a_uu:
89858988
case GenISAIntrinsic::GenISA_dp4a_su:
@@ -13802,6 +13805,8 @@ void EmitPass::emitReductionClustered(const e_opcode op, const uint64_t identity
1380213805
}
1380313806
}
1380413807

13808+
// Emits interleave reduction, first preparing the input data. This guarantees to produce
13809+
// correct result even if not all lanes are active.
1380513810
void EmitPass::emitReductionInterleave(const e_opcode op, const uint64_t identityValue, const VISA_Type type,
1380613811
const bool negate, const unsigned int step, CVariable* const src, CVariable* const dst)
1380713812
{
@@ -13819,17 +13824,31 @@ void EmitPass::emitReductionInterleave(const e_opcode op, const uint64_t identit
1381913824

1382013825
CVariable* srcH1 = ScanReducePrepareSrc(type, identityValue, negate, false /* secondHalf */,
1382113826
src, nullptr /* dst */);
13822-
CVariable* temp = srcH1;
13827+
13828+
CVariable* srcH2 = nullptr;
13829+
if (firstStep == 16 && m_currShader->m_numberInstance > 1)
13830+
{
13831+
srcH2 = ScanReducePrepareSrc(type, identityValue, negate, true /* secondHalf */,
13832+
src, nullptr /* dst */);
13833+
}
13834+
13835+
emitReductionInterleave(op, type, m_currShader->m_SIMDSize, step, false, srcH1, srcH2, dst);
13836+
}
13837+
13838+
// Directly emits interleave reduction on input data, without preparing the input.
13839+
void EmitPass::emitReductionInterleave(const e_opcode op, const VISA_Type type, const SIMDMode simd,
13840+
const unsigned int step, const bool noMaskBroadcast, CVariable* const src1, CVariable* const src2, CVariable* const dst)
13841+
{
13842+
const uint16_t firstStep = m_currShader->m_numberInstance * numLanes(simd) / 2;
13843+
13844+
CVariable* temp = src1;
1382313845

1382413846
// Implementation is similar to emitReductionAll(), but we stop reduction before reaching SIMD1.
1382513847
for (unsigned int currentStep = firstStep; currentStep >= step; currentStep >>= 1)
1382613848
{
1382713849
if (currentStep == 16 && m_currShader->m_numberInstance > 1)
1382813850
{
13829-
CVariable* srcH2 = ScanReducePrepareSrc(type, identityValue, negate, true /* secondHalf */,
13830-
src, nullptr /* dst */);
13831-
13832-
temp = ReductionReduceHelper(op, type, SIMDMode::SIMD16, temp, srcH2);
13851+
temp = ReductionReduceHelper(op, type, SIMDMode::SIMD16, temp, src2);
1383313852
}
1383413853
else
1383513854
{
@@ -13838,15 +13857,18 @@ void EmitPass::emitReductionInterleave(const e_opcode op, const uint64_t identit
1383813857
}
1383913858

1384013859
// Broadcast result
13860+
if (noMaskBroadcast)
13861+
m_encoder->SetNoMask();
13862+
1384113863
// For XeHP, for low interleave step, broadcast of 64-bit result
1384213864
// can be optimized as a separate mov of low/high 32-bit.
1384313865
bool use32bitMove = ScanReduceIs64BitType(type) && m_currShader->m_Platform->doScalar64bScan() && m_currShader->m_numberInstance == 1;
1384413866
if (use32bitMove && (step == 2 || step == 4))
1384513867
{
1384613868
CVariable* result32b = m_currShader->GetNewAlias(temp, ISA_TYPE_UD, 0, 2 * step);
13847-
CVariable* dst32b = m_currShader->GetNewAlias(dst, ISA_TYPE_UD, 0, 2 * numLanes(m_currShader->m_SIMDSize));
13869+
CVariable* dst32b = m_currShader->GetNewAlias(dst, ISA_TYPE_UD, 0, 2 * numLanes(simd));
1384813870

13849-
m_encoder->SetSimdSize(m_currShader->m_SIMDSize);
13871+
m_encoder->SetSimdSize(simd);
1385013872
m_encoder->SetSrcRegion(0, 0, step, 2);
1385113873
m_encoder->SetDstRegion(2);
1385213874
m_encoder->Copy(dst32b, result32b);
@@ -13859,7 +13881,7 @@ void EmitPass::emitReductionInterleave(const e_opcode op, const uint64_t identit
1385913881
return;
1386013882
}
1386113883

13862-
m_encoder->SetSimdSize(m_currShader->m_SIMDSize);
13884+
m_encoder->SetSimdSize(simd);
1386313885
m_encoder->SetSrcRegion(0, 0, step, 1);
1386413886
m_encoder->Copy(dst, temp);
1386513887
if (m_currShader->m_numberInstance > 1)
@@ -13871,6 +13893,119 @@ void EmitPass::emitReductionInterleave(const e_opcode op, const uint64_t identit
1387113893
m_encoder->Push();
1387213894
}
1387313895

13896+
void EmitPass::emitReductionClusteredInterleave(const e_opcode op, const uint64_t identityValue, const VISA_Type type,
13897+
const bool negate, const unsigned int clusterSize, const unsigned int interleaveStep, CVariable* const src, CVariable* const dst)
13898+
{
13899+
IGC_ASSERT_MESSAGE(!dst->IsUniform(), "Unsupported: dst must be non-uniform");
13900+
13901+
auto simd = m_currShader->m_SIMDSize;
13902+
auto dataSizeInBytes = CEncoder::GetCISADataTypeSize(type);
13903+
13904+
// If src spans 4 GRFs and cluster spans 2 GRFs (2 clusters total), then WaveClusterInterleave can be expressed
13905+
// as 2 x WaveInterleave, one for each pair of GRFs.
13906+
if (m_currShader->m_numberInstance == 1 && 2 * clusterSize == numLanes(simd) &&
13907+
numLanes(simd) * dataSizeInBytes == 4 * m_currShader->getGRFSize())
13908+
{
13909+
auto interleaveLanes = numLanes(simd) / 2;
13910+
SIMDMode interleaveSIMD = lanesToSIMDMode(interleaveLanes);
13911+
13912+
for (int i = 0; i < 2; ++i)
13913+
{
13914+
CVariable* srcAlias = m_currShader->GetNewAlias(src, type, i * interleaveLanes * dataSizeInBytes, interleaveLanes);
13915+
CVariable* dstAlias = m_currShader->GetNewAlias(dst, type, i * interleaveLanes * dataSizeInBytes, interleaveLanes);
13916+
13917+
emitReductionInterleave(op, type, interleaveSIMD, interleaveStep, true, srcAlias, nullptr, dstAlias);
13918+
}
13919+
13920+
return;
13921+
}
13922+
13923+
// Implementation for each case is custom, with no general solution.
13924+
13925+
if (m_currShader->m_numberInstance == 1 && simd == SIMDMode::SIMD32 && dataSizeInBytes == 4 && clusterSize == 16 && interleaveStep == 2)
13926+
{
13927+
CVariable* temp = m_currShader->GetNewVariable(numLanes(simd), type, EALIGN_GRF, false, "reduceSrc");
13928+
13929+
// Reorder input. Spread every value by two lanes.
13930+
//
13931+
// | 0 | 16 | 1 | 17 | 2 | 18 | ... | 15 | 31 |
13932+
for (int i = 0; i < 2; ++i)
13933+
{
13934+
m_encoder->SetNoMask();
13935+
m_encoder->SetSimdSize(SIMDMode::SIMD16);
13936+
m_encoder->SetSrcRegion(0, 1, 1, 0);
13937+
m_encoder->SetSrcSubReg(0, 16 * i);
13938+
m_encoder->SetDstRegion(2);
13939+
m_encoder->SetDstSubReg(i);
13940+
m_encoder->Copy(temp, src);
13941+
m_encoder->Push();
13942+
}
13943+
13944+
// Reduce.
13945+
temp = ReductionReduceHelper(op, type, SIMDMode::SIMD16, temp);
13946+
temp = ReductionReduceHelper(op, type, SIMDMode::SIMD8, temp);
13947+
temp = ReductionReduceHelper(op, type, SIMDMode::SIMD4, temp);
13948+
13949+
// Propagate output. Repeat each value 8 times.
13950+
// temp: | a | b | c | d |
13951+
// dst: | a | c | a | c | a | c | a | c | ... | b | d | b | d | b | d | b | d |
13952+
for (int i = 0; i < 2; ++i)
13953+
{
13954+
m_encoder->SetNoMask();
13955+
m_encoder->SetSimdSize(SIMDMode::SIMD16);
13956+
m_encoder->SetSrcRegion(0, 1, 8, 0);
13957+
m_encoder->SetSrcSubReg(0, 2 * i);
13958+
m_encoder->SetDstRegion(2);
13959+
m_encoder->SetDstSubReg(i);
13960+
m_encoder->Copy(dst, temp);
13961+
m_encoder->Push();
13962+
}
13963+
}
13964+
else if (m_currShader->m_numberInstance == 1 && simd == SIMDMode::SIMD32 && dataSizeInBytes == 4 && clusterSize == 8 && interleaveStep == 2)
13965+
{
13966+
CVariable* temp = m_currShader->GetNewVariable(numLanes(simd), type, EALIGN_GRF, false, "reduceSrc");
13967+
13968+
// Reorder input. Spread every next two values by 8 lanes:
13969+
//
13970+
// | 0 | 1 | 8 | 9 | 16 | 17 | ... | 14 | 15 | 22 | 23 | 30 | 31 |
13971+
for (int i = 0; i < 4; ++i)
13972+
{
13973+
m_encoder->SetNoMask();
13974+
m_encoder->SetSimdSize(SIMDMode::SIMD8);
13975+
m_encoder->SetSrcRegion(0, 8, 2, 1);
13976+
m_encoder->SetSrcSubReg(0, 2 * i);
13977+
m_encoder->SetDstRegion(1);
13978+
m_encoder->SetDstSubReg(8 * i);
13979+
m_encoder->Copy(temp, src);
13980+
m_encoder->Push();
13981+
}
13982+
13983+
// Reduce.
13984+
temp = ReductionReduceHelper(op, type, SIMDMode::SIMD16, temp);
13985+
temp = ReductionReduceHelper(op, type, SIMDMode::SIMD8, temp);
13986+
13987+
// Propagate output. Repeat each pair of values 4 times.
13988+
//
13989+
// temp: | a | b | c | d | e | f | g | h |
13990+
// dst: | a | b | a | b | a | b | a | b | ... | g | h | g | h | g | h | g | h |
13991+
for (int i = 0; i < 2; ++i)
13992+
{
13993+
m_encoder->SetNoMask();
13994+
m_encoder->SetSimdSize(SIMDMode::SIMD16);
13995+
m_encoder->SetSrcRegion(0, 2, 4, 0);
13996+
m_encoder->SetSrcSubReg(0, i);
13997+
m_encoder->SetDstRegion(2);
13998+
m_encoder->SetDstSubReg(i);
13999+
m_encoder->Copy(dst, temp);
14000+
m_encoder->Push();
14001+
}
14002+
}
14003+
else
14004+
{
14005+
IGC_ASSERT_MESSAGE(false, "Invalid WaveClusteredInterleave.");
14006+
}
14007+
}
14008+
1387414009
// do prefix op across all activate channels
1387514010
void EmitPass::emitPreOrPostFixOp(
1387614011
e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc,
@@ -21384,6 +21519,30 @@ void EmitPass::emitWaveInterleave(llvm::GenIntrinsicInst* inst)
2138421519
}
2138521520
}
2138621521

21522+
void EmitPass::emitWaveClusteredInterleave(llvm::GenIntrinsicInst* inst)
21523+
{
21524+
bool disableHelperLanes = int_cast<int>(cast<ConstantInt>(inst->getArgOperand(3))->getSExtValue()) == 2;
21525+
if (disableHelperLanes)
21526+
{
21527+
ForceDMask();
21528+
}
21529+
CVariable* src = GetSymbol(inst->getOperand(0));
21530+
const WaveOps op = static_cast<WaveOps>(cast<llvm::ConstantInt>(inst->getOperand(1))->getZExtValue());
21531+
const unsigned int clusterSize = int_cast<uint32_t>(cast<llvm::ConstantInt>(inst->getOperand(2))->getZExtValue());
21532+
const unsigned int interleaveStep = int_cast<uint32_t>(cast<llvm::ConstantInt>(inst->getOperand(3))->getZExtValue());
21533+
VISA_Type type;
21534+
e_opcode opCode;
21535+
uint64_t identity = 0;
21536+
GetReductionOp(op, inst->getOperand(0)->getType(), identity, opCode, type);
21537+
CVariable* dst = m_destination;
21538+
m_encoder->SetSubSpanDestination(false);
21539+
emitReductionClusteredInterleave(opCode, identity, type, false, clusterSize, interleaveStep, src, dst);
21540+
if (disableHelperLanes)
21541+
{
21542+
ResetVMask();
21543+
}
21544+
}
21545+
2138721546
void EmitPass::emitDP4A(GenIntrinsicInst* GII, const SSource* Sources, const DstModifier& modifier, bool isAccSigned) {
2138821547
GenISAIntrinsic::ID GIID = GII->getIntrinsicID();
2138921548
CVariable* dst = m_destination;

IGC/Compiler/CISACodeGen/EmitVISAPass.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,24 @@ class EmitPass : public llvm::FunctionPass
334334
const unsigned int step,
335335
CVariable* const src,
336336
CVariable* const dst);
337+
void emitReductionInterleave(
338+
const e_opcode op,
339+
const VISA_Type type,
340+
const SIMDMode simd,
341+
const unsigned int step,
342+
const bool noMaskBroadcast,
343+
CVariable* const src1,
344+
CVariable* const src2,
345+
CVariable* const dst);
346+
void emitReductionClusteredInterleave(
347+
const e_opcode op,
348+
const uint64_t identityValue,
349+
const VISA_Type type,
350+
const bool negate,
351+
const unsigned int clusterSize,
352+
const unsigned int interleaveStep,
353+
CVariable* const src,
354+
CVariable* const dst);
337355
void emitPreOrPostFixOp(
338356
e_opcode op,
339357
uint64_t identityValue,
@@ -442,6 +460,7 @@ class EmitPass : public llvm::FunctionPass
442460
void emitWaveAll(llvm::GenIntrinsicInst* inst);
443461
void emitWaveClustered(llvm::GenIntrinsicInst* inst);
444462
void emitWaveInterleave(llvm::GenIntrinsicInst* inst);
463+
void emitWaveClusteredInterleave(llvm::GenIntrinsicInst* inst);
445464

446465
// Those three "vector" version shall be combined with
447466
// non-vector version.

IGC/Compiler/CISACodeGen/HalfPromotion.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ void IGC::HalfPromotion::handleGenIntrinsic(llvm::GenIntrinsicInst& I)
114114
if (id == GenISAIntrinsic::GenISA_WaveAll ||
115115
id == GenISAIntrinsic::GenISA_WavePrefix ||
116116
id == GenISAIntrinsic::GenISA_WaveClustered ||
117-
id == GenISAIntrinsic::GenISA_WaveInterleave)
117+
id == GenISAIntrinsic::GenISA_WaveInterleave ||
118+
id == GenISAIntrinsic::GenISA_WaveClusteredInterleave)
118119
{
119120
Module* M = I.getParent()->getParent()->getParent();
120121
llvm::IGCIRBuilder<> builder(&I);

IGC/Compiler/CISACodeGen/PatternMatchPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,6 +1371,7 @@ namespace IGC
13711371
case GenISAIntrinsic::GenISA_WaveAll:
13721372
case GenISAIntrinsic::GenISA_WaveClustered:
13731373
case GenISAIntrinsic::GenISA_WaveInterleave:
1374+
case GenISAIntrinsic::GenISA_WaveClusteredInterleave:
13741375
case GenISAIntrinsic::GenISA_WavePrefix:
13751376
match = MatchWaveInstruction(*GII);
13761377
break;
@@ -5189,6 +5190,7 @@ namespace IGC
51895190
helperLaneIndex = 3;
51905191
break;
51915192
case GenISAIntrinsic::GenISA_WavePrefix:
5193+
case GenISAIntrinsic::GenISA_WaveClusteredInterleave:
51925194
helperLaneIndex = 4;
51935195
break;
51945196
default:

IGC/Compiler/CISACodeGen/PromoteInt8Type.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,7 @@ void PromoteInt8Type::promoteIntrinsic()
11351135
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveAll) ||
11361136
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveClustered) ||
11371137
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveInterleave) ||
1138+
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveClusteredInterleave) ||
11381139
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WavePrefix) ||
11391140
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_QuadPrefix))
11401141
{
@@ -1160,6 +1161,7 @@ void PromoteInt8Type::promoteIntrinsic()
11601161
if (gid == GenISAIntrinsic::GenISA_WaveAll ||
11611162
gid == GenISAIntrinsic::GenISA_WaveClustered ||
11621163
gid == GenISAIntrinsic::GenISA_WaveInterleave ||
1164+
gid == GenISAIntrinsic::GenISA_WaveClusteredInterleave ||
11631165
gid == GenISAIntrinsic::GenISA_WavePrefix ||
11641166
gid == GenISAIntrinsic::GenISA_QuadPrefix ||
11651167
gid == GenISAIntrinsic::GenISA_WaveShuffleIndex ||
@@ -1212,8 +1214,10 @@ void PromoteInt8Type::promoteIntrinsic()
12121214
break;
12131215
}
12141216
case GenISAIntrinsic::GenISA_WavePrefix:
1217+
case GenISAIntrinsic::GenISA_WaveClusteredInterleave:
12151218
{
12161219
// prototype: Ty <waveprefix> (Ty, char, bool, bool, int)
1220+
// prototype: Ty <clusteredInterleave> (Ty, char, int, int, int)
12171221
iArgs.push_back(GII->getArgOperand(1));
12181222
iArgs.push_back(GII->getArgOperand(2));
12191223
iArgs.push_back(GII->getArgOperand(3));

IGC/Compiler/CISACodeGen/WIAnalysis.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1423,6 +1423,7 @@ WIAnalysis::WIDependancy WIAnalysisRunner::calculate_dep(const CallInst* inst)
14231423
intrinsic_name == llvm_waveAll ||
14241424
intrinsic_name == llvm_waveClustered ||
14251425
intrinsic_name == llvm_waveInterleave ||
1426+
intrinsic_name == llvm_waveClusteredInterleave ||
14261427
intrinsic_name == llvm_ld_ptr ||
14271428
intrinsic_name == llvm_ldlptr ||
14281429
(IGC_IS_FLAG_DISABLED(DisableUniformTypedAccess) && intrinsic_name == llvm_typed_read) ||
@@ -1733,7 +1734,8 @@ WIAnalysis::WIDependancy WIAnalysisRunner::calculate_dep(const CallInst* inst)
17331734
}
17341735
}
17351736

1736-
if (intrinsic_name == llvm_waveInterleave)
1737+
if (intrinsic_name == llvm_waveInterleave ||
1738+
intrinsic_name == llvm_waveClusteredInterleave)
17371739
{
17381740
return WIAnalysis::RANDOM;
17391741
}

IGC/Compiler/CISACodeGen/helper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,6 +1790,7 @@ namespace IGC
17901790
return (opcode == llvm_waveAll ||
17911791
opcode == llvm_waveClustered ||
17921792
opcode == llvm_waveInterleave ||
1793+
opcode == llvm_waveClusteredInterleave ||
17931794
opcode == llvm_wavePrefix ||
17941795
opcode == llvm_waveShuffleIndex ||
17951796
opcode == llvm_waveBroadcast ||

IGC/Compiler/CISACodeGen/opCode.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ DECLARE_OPCODE(GenISA_WaveBallot, GenISAIntrinsic, llvm_waveBallot, false, false
284284
DECLARE_OPCODE(GenISA_WaveAll, GenISAIntrinsic, llvm_waveAll, false, false, false, false, false, false, false)
285285
DECLARE_OPCODE(GenISA_WaveClustered, GenISAIntrinsic, llvm_waveClustered, false, false, false, false, false, false, false)
286286
DECLARE_OPCODE(GenISA_WaveInterleave, GenISAIntrinsic, llvm_waveInterleave, false, false, false, false, false, false, false)
287+
DECLARE_OPCODE(GenISA_WaveClusteredInterleave, GenISAIntrinsic, llvm_waveClusteredInterleave, false, false, false, false, false, false, false)
287288
DECLARE_OPCODE(GenISA_WavePrefix, GenISAIntrinsic, llvm_wavePrefix, false, false, false, false, false, false, false)
288289
DECLARE_OPCODE(GenISA_QuadPrefix, GenISAIntrinsic, llvm_quadPrefix, false, false, false, false, false, false, false)
289290
DECLARE_OPCODE(GenISA_WaveShuffleIndex, GenISAIntrinsic, llvm_waveShuffleIndex, false, false, false, false, false, false, false)

0 commit comments

Comments
 (0)