Skip to content

Commit 9e7994c

Browse files
committed
amdilc: added support for atomic operations on raw and structured UAVs and for LDS atomic instructions
Implemented separate address calculations for raw and structured UAVs according to the specs, and added translation for the LDS atomic instructions (both structured and raw) that are used in BF4.
1 parent d12a61d commit 9e7994c

File tree

1 file changed

+127
-23
lines changed

1 file changed

+127
-23
lines changed

src/amdilc/amdilc_compiler.c

Lines changed: 127 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ typedef struct {
4949
uint32_t ilId;
5050
uint8_t ilType;
5151
IlcSpvId strideId;
52+
bool structured;
5253
} IlcResource;
5354

5455
typedef struct {
@@ -1149,6 +1150,7 @@ static void emitResource(
11491150
.ilId = id,
11501151
.ilType = type,
11511152
.strideId = 0,
1153+
.structured = false,
11521154
};
11531155

11541156
addResource(compiler, &resource);
@@ -1216,6 +1218,7 @@ static void emitTypedUav(
12161218
.ilId = id,
12171219
.ilType = type,
12181220
.strideId = 0,
1221+
.structured = false,
12191222
};
12201223

12211224
addResource(compiler, &resource);
@@ -1254,6 +1257,7 @@ static void emitUav(
12541257
.ilType = IL_USAGE_PIXTEX_UNKNOWN,
12551258
.strideId = ilcSpvPutConstant(compiler->module, compiler->intId,
12561259
isStructured ? instr->extras[0] : 4),
1260+
.structured = isStructured,
12571261
};
12581262

12591263
addResource(compiler, &resource);
@@ -1337,6 +1341,7 @@ static void emitSrv(
13371341
.ilId = id,
13381342
.ilType = IL_USAGE_PIXTEX_UNKNOWN,
13391343
.strideId = strideId,
1344+
.structured = isStructured,
13401345
};
13411346

13421347
addResource(compiler, &resource);
@@ -1364,6 +1369,7 @@ static void emitLds(
13641369
.ilId = id,
13651370
.ilType = IL_USAGE_PIXTEX_UNKNOWN,
13661371
.strideId = ilcSpvPutConstant(compiler->module, compiler->intId, 4),
1372+
.structured = false,
13671373
};
13681374

13691375
addResource(compiler, &resource);
@@ -1392,6 +1398,7 @@ static void emitStructuredLds(
13921398
.ilId = id,
13931399
.ilType = IL_USAGE_PIXTEX_UNKNOWN,
13941400
.strideId = ilcSpvPutConstant(compiler->module, compiler->intId, stride),
1401+
.structured = true,
13951402
};
13961403

13971404
addResource(compiler, &resource);
@@ -2461,24 +2468,27 @@ static void emitLdsStoreVec(
24612468
IlcSpvId wordAddrId = ilcSpvPutOp2(compiler->module, SpvOpSDiv, compiler->intId,
24622469
byteAddrId, fourId);
24632470

2464-
IlcSpvId oneId = ilcSpvPutConstant(compiler->module, compiler->intId, 1);
24652471
IlcSpvId ptrTypeId = ilcSpvPutPointerType(compiler->module, SpvStorageClassWorkgroup,
24662472
resource->texelTypeId);
24672473

24682474
// Write up to four components based on the destination mask
24692475
for (unsigned i = 0; i < 4; i++) {
24702476
if (dst->component[i] == IL_MODCOMP_NOWRITE) {
2471-
break;
2477+
continue;
24722478
}
24732479

2480+
IlcSpvId addrId;
24742481
if (i > 0) {
24752482
// Increment address
2476-
wordAddrId = ilcSpvPutOp2(compiler->module, SpvOpIAdd, compiler->intId,
2477-
wordAddrId, oneId);
2483+
IlcSpvId offsetId = ilcSpvPutConstant(compiler->module, compiler->intId, i);
2484+
addrId = ilcSpvPutOp2(compiler->module, SpvOpIAdd, compiler->intId,
2485+
wordAddrId, offsetId);
2486+
} else {
2487+
addrId = wordAddrId;
24782488
}
24792489

24802490
IlcSpvId ptrId = ilcSpvPutAccessChain(compiler->module, ptrTypeId, resource->id,
2481-
1, &wordAddrId);
2491+
1, &addrId);
24822492
IlcSpvId componentId = emitVectorTrim(compiler, dataId, compiler->uint4Id, i, 1);
24832493
ilcSpvPutStore(compiler->module, ptrId, componentId);
24842494
}
@@ -2555,28 +2565,89 @@ static void emitStructUavStore(
25552565
IlcSpvId byteAddrId = ilcSpvPutOp2(compiler->module, SpvOpIAdd, compiler->intId, baseId, offsetId);
25562566
IlcSpvId wordAddrId = ilcSpvPutOp2(compiler->module, SpvOpSDiv, compiler->intId, byteAddrId, ilcSpvPutConstant(compiler->module, compiler->intId, 4));
25572567

2558-
IlcSpvId oneId = ilcSpvPutConstant(compiler->module, compiler->intId, 1);
2568+
IlcSpvId zeroId = ilcSpvPutConstant(compiler->module, compiler->intId, 0);
25592569
IlcSpvId ptrTypeId = ilcSpvPutPointerType(compiler->module, SpvStorageClassStorageBuffer,
25602570
resource->texelTypeId);
25612571
// Write up to four components based on the destination mask
25622572
for (unsigned i = 0; i < 4; i++) {
25632573
if (dst->component[i] == IL_MODCOMP_NOWRITE) {
2564-
break;
2574+
continue;
25652575
}
25662576

2577+
IlcSpvId addrId;
25672578
if (i > 0) {
2568-
// Increment address
2569-
wordAddrId = ilcSpvPutOp2(compiler->module, SpvOpIAdd, compiler->intId,
2570-
wordAddrId, oneId);
2579+
// calculate address
2580+
IlcSpvId offsetId = ilcSpvPutConstant(compiler->module, compiler->intId, i);
2581+
addrId = ilcSpvPutOp2(compiler->module, SpvOpIAdd, compiler->intId,
2582+
wordAddrId, offsetId);
2583+
} else {
2584+
addrId = wordAddrId;
25712585
}
25722586

2587+
IlcSpvId indexIds[] = { zeroId, addrId };
25732588
IlcSpvId ptrId = ilcSpvPutAccessChain(compiler->module, ptrTypeId, resource->id,
2574-
1, &wordAddrId);
2589+
2, indexIds);
25752590
IlcSpvId componentId = emitVectorTrim(compiler, elementId, elementTypeId, i, 1);
25762591
ilcSpvPutStore(compiler->module, ptrId, componentId);
25772592
}
25782593
}
25792594

2595+
// Emits an atomic add or unsigned-max on an LDS resource for the
// IL_OP_LDS_ADD / IL_OP_LDS_READ_ADD / IL_OP_LDS_UMAX / IL_OP_LDS_READ_UMAX opcodes.
//
// compiler: compiler state; its module receives the emitted operations.
// instr:    instruction being translated. srcs[0] supplies the address, srcs[1] the
//           operand value, and dsts[0] (when dstCount > 0) receives the atomic's result.
//
// Logs an error and emits nothing when the referenced LDS resource was not declared.
static void emitLdsAtomicOp(
2596+
IlcCompiler* compiler,
2597+
const Instruction* instr)
2598+
{
2599+
// The LDS resource ID is extracted from the instruction's control field
uint8_t ilResourceId = GET_BITS(instr->control, 0, 4);
2600+
2601+
const IlcResource* resource = findResource(compiler, RES_TYPE_LDS, ilResourceId);
2602+
2603+
if (resource == NULL) {
2604+
LOGE("resource %d not found\n", ilResourceId);
2605+
return;
2606+
}
2607+
2608+
IlcSpvId pointerTypeId = ilcSpvPutPointerType(compiler->module, SpvStorageClassWorkgroup,
2609+
resource->texelTypeId);
2610+
IlcSpvId addressId = loadSource(compiler, &instr->srcs[0], COMP_MASK_XYZW, compiler->int4Id);
2611+
IlcSpvId byteAddrId;
2612+
// Structured LDS: address.x is the element index and address.y the offset inside the element.
// Raw LDS: address.x is used directly as the byte address.
if (resource->structured) {
2613+
IlcSpvId indexId = emitVectorTrim(compiler, addressId, compiler->int4Id, COMP_INDEX_X, 1);
2614+
IlcSpvId offsetId = emitVectorTrim(compiler, addressId, compiler->int4Id, COMP_INDEX_Y, 1);
2615+
// addr = (index * stride + offset) / 4
2616+
IlcSpvId baseId = ilcSpvPutOp2(compiler->module, SpvOpIMul, compiler->intId,
2617+
indexId, resource->strideId);
2618+
byteAddrId = ilcSpvPutOp2(compiler->module, SpvOpIAdd, compiler->intId,
2619+
baseId, offsetId);
2620+
} else {
2621+
byteAddrId = emitVectorTrim(compiler, addressId, compiler->int4Id, COMP_INDEX_X, 1);
2622+
}
2623+
// Convert the byte address to an element index by dividing by 4 (signed division)
IlcSpvId wordAddrId = ilcSpvPutOp2(compiler->module, SpvOpSDiv, compiler->intId,
2624+
byteAddrId, ilcSpvPutConstant(compiler->module, compiler->intId, 4));
2625+
IlcSpvId bufferPtrId = ilcSpvPutAccessChain(compiler->module, pointerTypeId, resource->id,
2626+
1, &wordAddrId);
2627+
// Stays 0 unless one of the opcode branches below emits an atomic
IlcSpvId readId = 0;
2628+
IlcSpvId vecTypeId = ilcSpvPutVectorType(compiler->module, resource->texelTypeId, 4);
2629+
IlcSpvId scopeId = ilcSpvPutConstant(compiler->module, compiler->intId, SpvScopeWorkgroup);
2630+
// NOTE(review): the pointer above uses SpvStorageClassWorkgroup, yet the semantics select
// SubgroupMemory; SpvMemorySemanticsWorkgroupMemoryMask looks like the intended bit -- confirm
// against the SPIR-V memory semantics table.
IlcSpvId semanticsId = ilcSpvPutConstant(compiler->module, compiler->intId,
2631+
SpvMemorySemanticsAcquireReleaseMask |
2632+
SpvMemorySemanticsSubgroupMemoryMask);
2633+
// Only the X component of srcs[1] is used as the atomic operand
IlcSpvId src1Id = loadSource(compiler, &instr->srcs[1], COMP_MASK_XYZW, vecTypeId);
2634+
IlcSpvId valueId = emitVectorTrim(compiler, src1Id, vecTypeId, COMP_INDEX_X, 1);
2635+
2636+
// The READ and non-READ variants of each opcode emit the same atomic operation
if (instr->opcode == IL_OP_LDS_ADD || instr->opcode == IL_OP_LDS_READ_ADD) {
2637+
readId = ilcSpvPutAtomicOp(compiler->module, SpvOpAtomicIAdd, resource->texelTypeId,
2638+
bufferPtrId, scopeId, semanticsId, valueId);
2639+
} else if (instr->opcode == IL_OP_LDS_UMAX || instr->opcode == IL_OP_LDS_READ_UMAX) {
2640+
readId = ilcSpvPutAtomicOp(compiler->module, SpvOpAtomicUMax, resource->texelTypeId,
2641+
bufferPtrId, scopeId, semanticsId, valueId);
2642+
} else {
2643+
// NOTE(review): if asserts are compiled out (NDEBUG), execution continues with readId == 0
assert(false);
2644+
}
2645+
2646+
// Write the atomic's result to the destination only when the instruction declares one
if (instr->dstCount > 0) {
2647+
IlcSpvId resId = emitVectorGrow(compiler, readId, resource->texelTypeId, 1);
2648+
storeDestination(compiler, &instr->dsts[0], resId, vecTypeId);
2649+
}
2650+
}
25802651

25812652
static void emitUavAtomicOp(
25822653
IlcCompiler* compiler,
@@ -2591,21 +2662,48 @@ static void emitUavAtomicOp(
25912662
return;
25922663
}
25932664

2594-
IlcSpvId vecTypeId = ilcSpvPutVectorType(compiler->module, resource->texelTypeId, 4);
2595-
IlcSpvId pointerTypeId = ilcSpvPutPointerType(compiler->module, SpvStorageClassImage,
2596-
resource->texelTypeId);
2597-
IlcSpvId addressId = loadSource(compiler, &instr->srcs[0], COMP_MASK_XYZW, compiler->int4Id);
2598-
IlcSpvId trimAddressId = emitVectorTrim(compiler, addressId, compiler->int4Id, COMP_INDEX_X,
2599-
getResourceDimensionCount(resource->ilType));
2600-
IlcSpvId zeroId = ilcSpvPutConstant(compiler->module, compiler->intId, ZERO_LITERAL);
2601-
IlcSpvId texelPtrId = ilcSpvPutImageTexelPointer(compiler->module, pointerTypeId, resource->id,
2602-
trimAddressId, zeroId);
2665+
IlcSpvId texelPtrId ;
26032666

2667+
IlcSpvId semanticsId;
2668+
if (resource->strideId == 0) {
2669+
IlcSpvId pointerTypeId = ilcSpvPutPointerType(compiler->module, SpvStorageClassImage,
2670+
resource->texelTypeId);
2671+
IlcSpvId addressId = loadSource(compiler, &instr->srcs[0], COMP_MASK_XYZW, compiler->int4Id);
2672+
IlcSpvId trimAddressId = emitVectorTrim(compiler, addressId, compiler->int4Id, COMP_INDEX_X,
2673+
getResourceDimensionCount(resource->ilType));
2674+
IlcSpvId zeroId = ilcSpvPutConstant(compiler->module, compiler->intId, ZERO_LITERAL);
2675+
texelPtrId = ilcSpvPutImageTexelPointer(compiler->module, pointerTypeId, resource->id,
2676+
trimAddressId, zeroId);
2677+
semanticsId = ilcSpvPutConstant(compiler->module, compiler->intId,
2678+
SpvMemorySemanticsAcquireReleaseMask |
2679+
SpvMemorySemanticsImageMemoryMask);
2680+
} else {
2681+
IlcSpvId pointerTypeId = ilcSpvPutPointerType(compiler->module, SpvStorageClassStorageBuffer,
2682+
resource->texelTypeId);
2683+
IlcSpvId addressId = loadSource(compiler, &instr->srcs[0], COMP_MASK_XYZW, compiler->int4Id);
2684+
IlcSpvId byteAddrId;
2685+
if (resource->structured) {
2686+
IlcSpvId indexId = emitVectorTrim(compiler, addressId, compiler->int4Id, COMP_INDEX_X, 1);
2687+
IlcSpvId offsetId = emitVectorTrim(compiler, addressId, compiler->int4Id, COMP_INDEX_Y, 1);
2688+
// addr = (index * stride + offset) / 4
2689+
IlcSpvId baseId = ilcSpvPutOp2(compiler->module, SpvOpIMul, compiler->intId, indexId, resource->strideId);
2690+
byteAddrId = ilcSpvPutOp2(compiler->module, SpvOpIAdd, compiler->intId, baseId, offsetId);
2691+
} else {
2692+
byteAddrId = emitVectorTrim(compiler, addressId, compiler->int4Id, COMP_INDEX_X, 1);
2693+
}
2694+
IlcSpvId zeroId = ilcSpvPutConstant(compiler->module, compiler->intId, ZERO_LITERAL);
2695+
IlcSpvId wordAddrId = ilcSpvPutOp2(compiler->module, SpvOpSDiv, compiler->intId,
2696+
byteAddrId, ilcSpvPutConstant(compiler->module, compiler->intId, 4));
2697+
const IlcSpvId indexIds[] = { zeroId, wordAddrId };
2698+
texelPtrId = ilcSpvPutAccessChain(compiler->module, pointerTypeId, resource->id,
2699+
2, indexIds);
2700+
semanticsId = ilcSpvPutConstant(compiler->module, compiler->intId,
2701+
SpvMemorySemanticsAcquireReleaseMask |
2702+
SpvMemorySemanticsUniformMemoryMask);
2703+
}
26042704
IlcSpvId readId = 0;
2705+
IlcSpvId vecTypeId = ilcSpvPutVectorType(compiler->module, resource->texelTypeId, 4);
26052706
IlcSpvId scopeId = ilcSpvPutConstant(compiler->module, compiler->intId, SpvScopeDevice);
2606-
IlcSpvId semanticsId = ilcSpvPutConstant(compiler->module, compiler->intId,
2607-
SpvMemorySemanticsAcquireReleaseMask |
2608-
SpvMemorySemanticsImageMemoryMask);
26092707
IlcSpvId src1Id = loadSource(compiler, &instr->srcs[1], COMP_MASK_XYZW, vecTypeId);
26102708
IlcSpvId valueId = emitVectorTrim(compiler, src1Id, vecTypeId, COMP_INDEX_X, 1);
26112709

@@ -3031,6 +3129,12 @@ static void emitInstr(
30313129
case IL_OP_UAV_READ_UMAX:
30323130
emitUavAtomicOp(compiler, instr);
30333131
break;
3132+
case IL_OP_LDS_ADD:
3133+
case IL_OP_LDS_READ_ADD:
3134+
case IL_OP_LDS_UMAX:
3135+
case IL_OP_LDS_READ_UMAX:
3136+
emitLdsAtomicOp(compiler, instr);
3137+
break;
30343138
case IL_OP_DCL_RAW_SRV:
30353139
case IL_OP_DCL_STRUCT_SRV:
30363140
emitSrv(compiler, instr);

0 commit comments

Comments
 (0)