Skip to content

Commit cd643e8

Browse files
committed
amdilc: add subgroup ballot for atomic alloc/consume instructions
1 parent 6d146c0 commit cd643e8

File tree

3 files changed

+183
-6
lines changed

3 files changed

+183
-6
lines changed

src/amdilc/amdilc_compiler.c

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2802,9 +2802,31 @@ static void emitAppendBufOp(
28022802
resource = addResource(compiler, &atomicCounterResource);
28032803
}
28042804

2805-
SpvOp op = instr->opcode == IL_OP_APPEND_BUF_ALLOC ? SpvOpAtomicIIncrement
2806-
: SpvOpAtomicIDecrement;
2807-
2805+
bool useSubgroupOps = compiler->kernel->shaderType == IL_SHADER_COMPUTE;
2806+
IlcSpvId electBlockBeginId;
2807+
IlcSpvId electBlockEndId;
2808+
IlcSpvId preElectBlockLabelId;
2809+
IlcSpvId workgroupScopeId;
2810+
IlcSpvId laneCountId, laneIndexId;
2811+
if (useSubgroupOps) {
2812+
preElectBlockLabelId = getTopBlockLabel(compiler);
2813+
electBlockBeginId = ilcSpvAllocId(compiler->module);
2814+
electBlockEndId = ilcSpvAllocId(compiler->module);
2815+
2816+
ilcSpvPutCapability(compiler->module, SpvCapabilityGroupNonUniform);
2817+
ilcSpvPutCapability(compiler->module, SpvCapabilityGroupNonUniformBallot);
2818+
workgroupScopeId = ilcSpvPutConstant(compiler->module, compiler->intId, SpvScopeWorkgroup);
2819+
2820+
IlcSpvId ballotId = ilcSpvPutGroupNonUniformBallot(compiler->module, compiler->uint4Id, workgroupScopeId, ilcSpvPutConstantTrue(compiler->module, compiler->boolId));
2821+
laneCountId = ilcSpvPutGroupNonUniformBallotBitCount(compiler->module, compiler->uintId, workgroupScopeId, SpvGroupOperationReduce, ballotId);
2822+
laneIndexId = ilcSpvPutGroupNonUniformBallotBitCount(compiler->module, compiler->uintId, workgroupScopeId, SpvGroupOperationExclusiveScan, ballotId);
2823+
IlcSpvId electionCondId = ilcSpvPutGroupNonUniformElect(compiler->module, compiler->boolId, workgroupScopeId);
2824+
ilcSpvPutSelectionMerge(compiler->module, electBlockEndId);
2825+
ilcSpvPutBranchConditional(compiler->module, electionCondId, electBlockBeginId, electBlockEndId);
2826+
ilcSpvPutLabel(compiler->module, electBlockBeginId);
2827+
} else {
2828+
laneCountId = ilcSpvPutConstant(compiler->module, compiler->uintId, 1u);
2829+
}
28082830
IlcSpvId ptrTypeId = ilcSpvPutPointerType(compiler->module, SpvStorageClassStorageBuffer,
28092831
compiler->uintId);
28102832
IlcSpvId zeroId = ilcSpvPutConstant(compiler->module, compiler->intId, ZERO_LITERAL);
@@ -2816,10 +2838,27 @@ static void emitAppendBufOp(
28162838
IlcSpvId semanticsId = ilcSpvPutConstant(compiler->module, compiler->intId,
28172839
SpvMemorySemanticsAcquireReleaseMask |
28182840
SpvMemorySemanticsUniformMemoryMask);
2819-
IlcSpvId readId = ilcSpvPutAtomicOp(compiler->module, op, compiler->uintId, ptrId,
2820-
scopeId, semanticsId, 0);
2821-
IlcSpvId resId = emitVectorGrow(compiler, readId, compiler->uintId, 1);
2841+
IlcSpvId readId;
2842+
SpvOp op = instr->opcode == IL_OP_APPEND_BUF_ALLOC ? SpvOpAtomicIAdd : SpvOpAtomicISub;
2843+
2844+
readId = ilcSpvPutAtomicOp(compiler->module, op, compiler->uintId, ptrId,
2845+
scopeId, semanticsId, laneCountId);
28222846

2847+
if (useSubgroupOps) {
2848+
ilcSpvPutBranch(compiler->module, electBlockEndId);
2849+
ilcSpvPutLabel(compiler->module, electBlockEndId);
2850+
2851+
IlcSpvId constUndefId = ilcSpvPutConstantUndef(compiler->module, compiler->uintId);
2852+
IlcSpvId phiLabels[4] = {
2853+
readId, electBlockBeginId,
2854+
constUndefId, preElectBlockLabelId,
2855+
};
2856+
readId = ilcSpvPutPhi(compiler->module, compiler->uintId, 4, phiLabels);
2857+
readId = ilcSpvPutGroupNonUniformBroadcastFirst(compiler->module, compiler->uintId, workgroupScopeId, readId);
2858+
readId = ilcSpvPutOp2(compiler->module, instr->opcode == IL_OP_APPEND_BUF_ALLOC ? SpvOpIAdd : SpvOpISub, compiler->uintId, readId, laneIndexId);
2859+
}
2860+
2861+
IlcSpvId resId = emitVectorGrow(compiler, readId, compiler->uintId, 1);
28232862
storeDestination(compiler, dst, resId, compiler->uint4Id);
28242863
}
28252864

src/amdilc/amdilc_spirv.c

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,13 @@ IlcSpvId ilcSpvPutConstant(
463463
return putConstant(module, SpvOpConstant, resultTypeId, 1, &literal);
464464
}
465465

// Returns an id for an OpUndef value of type resultTypeId.
// Delegates to putConstant() with SpvOpUndef and no literal words
// (count 0, NULL pointer). putConstant() is defined outside this view, so
// where the instruction is emitted and whether ids are reused is not visible
// here.
466+
IlcSpvId ilcSpvPutConstantUndef(
467+
IlcSpvModule* module,
468+
IlcSpvId resultTypeId)
469+
{
470+
return putConstant(module, SpvOpUndef, resultTypeId, 0, NULL);
471+
}
472+
466473
IlcSpvId ilcSpvPutConstantComposite(
467474
IlcSpvModule* module,
468475
IlcSpvId resultTypeId,
@@ -473,6 +480,13 @@ IlcSpvId ilcSpvPutConstantComposite(
473480
consistuentCount, consistuents);
474481
}
475482

// Returns an id for an OpConstantTrue of type resultTypeId.
// Delegates to putConstant() with SpvOpConstantTrue and no literal words
// (count 0, NULL pointer). NOTE(review): resultTypeId is presumably expected
// to be a boolean type id (the caller in this commit passes compiler->boolId);
// nothing here checks it.
483+
IlcSpvId ilcSpvPutConstantTrue(
484+
IlcSpvModule* module,
485+
IlcSpvId resultTypeId)
486+
{
487+
return putConstant(module, SpvOpConstantTrue, resultTypeId, 0, NULL);
488+
}
489+
476490
void ilcSpvPutFunction(
477491
IlcSpvModule* module,
478492
IlcSpvId resultType,
@@ -928,6 +942,92 @@ IlcSpvId ilcSpvPutBitcast(
928942
return id;
929943
}
930944

// Appends an OpGroupNonUniformBallot instruction to the module's ID_CODE
// buffer and returns the newly allocated result id.
//
// module       - module whose code buffer receives the instruction
// resultTypeId - id written as the result type operand
// scopeId      - id written as the execution scope operand
// predicateId  - id written as the predicate operand
//
// The 5 passed to putInstr() matches the opcode word plus the four operand
// words written below (presumably the instruction word count; putInstr() is
// defined outside this view).
945+
IlcSpvId ilcSpvPutGroupNonUniformBallot(
946+
IlcSpvModule* module,
947+
IlcSpvId resultTypeId,
948+
IlcSpvId scopeId,
949+
IlcSpvId predicateId)
950+
{
951+
IlcSpvBuffer* buffer = &module->buffer[ID_CODE];
952+
953+
IlcSpvId id = ilcSpvAllocId(module);
954+
putInstr(buffer, SpvOpGroupNonUniformBallot, 5);
955+
putWord(buffer, resultTypeId);
956+
putWord(buffer, id);
957+
putWord(buffer, scopeId);
958+
putWord(buffer, predicateId);
959+
return id;
960+
}
961+
// Appends an OpGroupNonUniformBallotBitCount instruction to the module's
// ID_CODE buffer and returns the newly allocated result id.
//
// module         - module whose code buffer receives the instruction
// resultTypeId   - id written as the result type operand
// scopeId        - id written as the execution scope operand
// groupOperation - written directly as a literal word (not an id), e.g.
//                  SpvGroupOperationReduce or SpvGroupOperationExclusiveScan
//                  as used by the caller in this commit
// valueId        - id written as the value operand (the caller in this
//                  commit passes a ballot result)
//
// The 6 passed to putInstr() matches the opcode word plus the five operand
// words written below.
962+
IlcSpvId ilcSpvPutGroupNonUniformBallotBitCount(
963+
IlcSpvModule* module,
964+
IlcSpvId resultTypeId,
965+
IlcSpvId scopeId,
966+
SpvGroupOperation groupOperation,
967+
IlcSpvId valueId)
968+
{
969+
IlcSpvBuffer* buffer = &module->buffer[ID_CODE];
970+
971+
IlcSpvId id = ilcSpvAllocId(module);
972+
putInstr(buffer, SpvOpGroupNonUniformBallotBitCount, 6);
973+
putWord(buffer, resultTypeId);
974+
putWord(buffer, id);
975+
putWord(buffer, scopeId);
976+
putWord(buffer, groupOperation);
977+
putWord(buffer, valueId);
978+
return id;
979+
}
980+
// Appends an OpGroupNonUniformElect instruction to the module's ID_CODE
// buffer and returns the newly allocated result id.
//
// module       - module whose code buffer receives the instruction
// resultTypeId - id written as the result type operand (the caller in this
//                commit passes compiler->boolId)
// scopeId      - id written as the execution scope operand
//
// The 4 passed to putInstr() matches the opcode word plus the three operand
// words written below.
981+
IlcSpvId ilcSpvPutGroupNonUniformElect(
982+
IlcSpvModule* module,
983+
IlcSpvId resultTypeId,
984+
IlcSpvId scopeId)
985+
{
986+
IlcSpvBuffer* buffer = &module->buffer[ID_CODE];
987+
988+
IlcSpvId id = ilcSpvAllocId(module);
989+
putInstr(buffer, SpvOpGroupNonUniformElect, 4);
990+
putWord(buffer, resultTypeId);
991+
putWord(buffer, id);
992+
putWord(buffer, scopeId);
993+
return id;
994+
}
995+
// Appends an OpGroupNonUniformBroadcastFirst instruction to the module's
// ID_CODE buffer and returns the newly allocated result id.
//
// module       - module whose code buffer receives the instruction
// resultTypeId - id written as the result type operand
// scopeId      - id written as the execution scope operand
// valueId      - id written as the value operand
//
// The 5 passed to putInstr() matches the opcode word plus the four operand
// words written below.
996+
IlcSpvId ilcSpvPutGroupNonUniformBroadcastFirst(
997+
IlcSpvModule* module,
998+
IlcSpvId resultTypeId,
999+
IlcSpvId scopeId,
1000+
IlcSpvId valueId)
1001+
{
1002+
IlcSpvBuffer* buffer = &module->buffer[ID_CODE];
1003+
1004+
IlcSpvId id = ilcSpvAllocId(module);
1005+
putInstr(buffer, SpvOpGroupNonUniformBroadcastFirst, 5);
1006+
putWord(buffer, resultTypeId);
1007+
putWord(buffer, id);
1008+
putWord(buffer, scopeId);
1009+
putWord(buffer, valueId);
1010+
return id;
1011+
}
1012+
// Appends an OpPhi instruction to the module's ID_CODE buffer and returns the
// newly allocated result id.
//
// module       - module whose code buffer receives the instruction
// resultTypeId - id written as the result type operand
// argCount     - number of words in args (NOT the number of pairs)
// args         - operand words copied verbatim, in order; the caller in this
//                commit passes a flat list of (value id, predecessor label id)
//                pairs, so argCount is twice the number of incoming edges
//
// The count passed to putInstr() is 3 (opcode word, result type, result id)
// plus one word per entry of args. argCount is not validated here (e.g. for
// being even).
1013+
IlcSpvId ilcSpvPutPhi(
1014+
IlcSpvModule* module,
1015+
IlcSpvId resultTypeId,
1016+
unsigned argCount,
1017+
const IlcSpvId* args)
1018+
{
1019+
IlcSpvBuffer* buffer = &module->buffer[ID_CODE];
1020+
1021+
IlcSpvId id = ilcSpvAllocId(module);
1022+
putInstr(buffer, SpvOpPhi, 3 + argCount);
1023+
putWord(buffer, resultTypeId);
1024+
putWord(buffer, id);
1025+
for (unsigned i = 0; i < argCount; ++i) {
1026+
putWord(buffer, args[i]);
1027+
}
1028+
return id;
1029+
}
1030+
9311031
IlcSpvId ilcSpvPutSelect(
9321032
IlcSpvModule* module,
9331033
IlcSpvId resultTypeId,

src/amdilc/amdilc_spirv.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,20 @@ IlcSpvId ilcSpvPutConstant(
149149
IlcSpvId resultTypeId,
150150
IlcSpvWord literal);
151151

// Returns an id for an OpUndef value of type resultTypeId.
152+
IlcSpvId ilcSpvPutConstantUndef(
153+
IlcSpvModule* module,
154+
IlcSpvId resultTypeId);
155+
152156
IlcSpvId ilcSpvPutConstantComposite(
153157
IlcSpvModule* module,
154158
IlcSpvId resultTypeId,
155159
unsigned consistuentCount,
156160
const IlcSpvId* consistuents);
157161

// Returns an id for an OpConstantTrue of type resultTypeId.
162+
IlcSpvId ilcSpvPutConstantTrue(
163+
IlcSpvModule* module,
164+
IlcSpvId resultTypeId);
165+
158166
void ilcSpvPutFunction(
159167
IlcSpvModule* module,
160168
IlcSpvId resultType,
@@ -323,6 +331,36 @@ IlcSpvId ilcSpvPutBitcast(
323331
IlcSpvId resultTypeId,
324332
IlcSpvId operandId);
325333

// Emits OpGroupNonUniformBallot (scope, predicate) and returns its result id.
334+
IlcSpvId ilcSpvPutGroupNonUniformBallot(
335+
IlcSpvModule* module,
336+
IlcSpvId resultTypeId,
337+
IlcSpvId scopeId,
338+
IlcSpvId predicateId);
339+
// Emits OpGroupNonUniformBallotBitCount and returns its result id.
// groupOperation is emitted as a literal word; the other operands are ids.
340+
IlcSpvId ilcSpvPutGroupNonUniformBallotBitCount(
341+
IlcSpvModule* module,
342+
IlcSpvId resultTypeId,
343+
IlcSpvId scopeId,
344+
SpvGroupOperation groupOperation,
345+
IlcSpvId valueId);
346+
// Emits OpGroupNonUniformElect for the given scope and returns its result id.
347+
IlcSpvId ilcSpvPutGroupNonUniformElect(
348+
IlcSpvModule* module,
349+
IlcSpvId resultTypeId,
350+
IlcSpvId scopeId);
351+
// Emits OpGroupNonUniformBroadcastFirst (scope, value) and returns its
// result id.
352+
IlcSpvId ilcSpvPutGroupNonUniformBroadcastFirst(
353+
IlcSpvModule* module,
354+
IlcSpvId resultTypeId,
355+
IlcSpvId scopeId,
356+
IlcSpvId valueId);
357+
// Emits OpPhi and returns its result id. args holds argCount words that are
// written in order after the result id; argCount counts words, not pairs.
358+
IlcSpvId ilcSpvPutPhi(
359+
IlcSpvModule* module,
360+
IlcSpvId resultTypeId,
361+
unsigned argCount,
362+
const IlcSpvId* args);
363+
326364
IlcSpvId ilcSpvPutSelect(
327365
IlcSpvModule* module,
328366
IlcSpvId resultTypeId,

0 commit comments

Comments
 (0)