Skip to content

Commit 2c913d6

Browse files
[SPIR-V] Fix push_constant and shader_record_ext to work with LowerTypeVisitor (microsoft#6011)
Since `ConstantBuffer` is now lowered in `LowerTypeVisitor`, lower `push_constant` and `shader_record_*` with `ConstantBuffer` types there as well. Fixes microsoft#5808.
1 parent 773fed3 commit 2c913d6

12 files changed

+164
-105
lines changed

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,7 +1307,7 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
13071307
const bool forShaderRecordNV =
13081308
usageKind == ContextUsageKind::ShaderRecordBufferNV;
13091309
const bool forShaderRecordEXT =
1310-
usageKind == ContextUsageKind::ShaderRecordBufferEXT;
1310+
usageKind == ContextUsageKind::ShaderRecordBufferKHR;
13111311

13121312
const auto &declGroup = collectDeclsInDeclContext(decl);
13131313

@@ -1357,9 +1357,6 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
13571357
}
13581358
}
13591359

1360-
// Register the <type-id> for this decl
1361-
ctBufferPCTypes[decl] = resultType;
1362-
13631360
const auto sc = forPC ? spv::StorageClass::PushConstant
13641361
: forShaderRecordNV ? spv::StorageClass::ShaderRecordBufferNV
13651362
: forShaderRecordEXT
@@ -1451,18 +1448,30 @@ SpirvVariable *DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
14511448
const QualType type = decl->getType();
14521449
const auto *recordType = type->getAs<RecordType>();
14531450

1451+
SpirvVariable *var = nullptr;
1452+
14541453
if (isConstantBuffer(type)) {
1455-
// Get the templated type for ConstantBuffer.
1456-
recordType = hlsl::GetHLSLResourceResultType(type)->getAs<RecordType>();
1457-
}
1454+
// Constant buffers already have Block decoration. The variable will need
1455+
// the PushConstant storage class.
14581456

1459-
assert(recordType);
1457+
// Create the variable for the whole struct / struct array.
1458+
// The fields may be 'precise', but the structure itself is not.
1459+
var = spvBuilder.addModuleVar(type, spv::StorageClass::PushConstant,
1460+
/*isPrecise*/ false,
1461+
/*isNoInterp*/ false, decl->getName());
14601462

1461-
const std::string structName =
1462-
"type.PushConstant." + recordType->getDecl()->getName().str();
1463-
SpirvVariable *var = createStructOrStructArrayVarOfExplicitLayout(
1464-
recordType->getDecl(), /*arraySize*/ 0, ContextUsageKind::PushConstant,
1465-
structName, decl->getName());
1463+
const SpirvLayoutRule layoutRule = spirvOptions.sBufferLayoutRule;
1464+
1465+
var->setHlslUserType("");
1466+
var->setLayoutRule(layoutRule);
1467+
} else {
1468+
assert(recordType);
1469+
const std::string structName =
1470+
"type.PushConstant." + recordType->getDecl()->getName().str();
1471+
var = createStructOrStructArrayVarOfExplicitLayout(
1472+
recordType->getDecl(), /*arraySize*/ 0, ContextUsageKind::PushConstant,
1473+
structName, decl->getName());
1474+
}
14661475

14671476
// Register the VarDecl
14681477
astDecls[decl] = createDeclSpirvInfo(var);
@@ -1476,22 +1485,44 @@ SpirvVariable *DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
14761485
SpirvVariable *
14771486
DeclResultIdMapper::createShaderRecordBuffer(const VarDecl *decl,
14781487
ContextUsageKind kind) {
1488+
const QualType type = decl->getType();
14791489
const auto *recordType =
1480-
hlsl::GetHLSLResourceResultType(decl->getType())->getAs<RecordType>();
1490+
hlsl::GetHLSLResourceResultType(type)->getAs<RecordType>();
14811491
assert(recordType);
14821492

1483-
assert(kind == ContextUsageKind::ShaderRecordBufferEXT ||
1493+
assert(kind == ContextUsageKind::ShaderRecordBufferKHR ||
14841494
kind == ContextUsageKind::ShaderRecordBufferNV);
14851495

1486-
const auto typeName = kind == ContextUsageKind::ShaderRecordBufferEXT
1487-
? "type.ShaderRecordBufferEXT."
1488-
: "type.ShaderRecordBufferNV.";
1496+
SpirvVariable *var = nullptr;
1497+
if (isConstantBuffer(type)) {
1498+
// Constant buffers already have Block decoration. The variable will need
1499+
// the appropriate storage class.
1500+
1501+
const auto sc = kind == ContextUsageKind::ShaderRecordBufferNV
1502+
? spv::StorageClass::ShaderRecordBufferNV
1503+
: spv::StorageClass::ShaderRecordBufferKHR;
1504+
1505+
// Create the variable for the whole struct / struct array.
1506+
// The fields may be 'precise', but the structure itself is not.
1507+
var = spvBuilder.addModuleVar(type, sc,
1508+
/*isPrecise*/ false,
1509+
/*isNoInterp*/ false, decl->getName());
14891510

1490-
const std::string structName =
1491-
typeName + recordType->getDecl()->getName().str();
1492-
SpirvVariable *var = createStructOrStructArrayVarOfExplicitLayout(
1493-
recordType->getDecl(), /*arraySize*/ 0, kind, structName,
1494-
decl->getName());
1511+
const SpirvLayoutRule layoutRule = spirvOptions.sBufferLayoutRule;
1512+
1513+
var->setHlslUserType("");
1514+
var->setLayoutRule(layoutRule);
1515+
} else {
1516+
const auto typeName = kind == ContextUsageKind::ShaderRecordBufferKHR
1517+
? "type.ShaderRecordBufferKHR."
1518+
: "type.ShaderRecordBufferNV.";
1519+
1520+
const std::string structName =
1521+
typeName + recordType->getDecl()->getName().str();
1522+
var = createStructOrStructArrayVarOfExplicitLayout(
1523+
recordType->getDecl(), /*arraySize*/ 0, kind, structName,
1524+
decl->getName());
1525+
}
14951526

14961527
// Register the VarDecl
14971528
astDecls[decl] = createDeclSpirvInfo(var);
@@ -1505,11 +1536,11 @@ DeclResultIdMapper::createShaderRecordBuffer(const VarDecl *decl,
15051536
SpirvVariable *
15061537
DeclResultIdMapper::createShaderRecordBuffer(const HLSLBufferDecl *decl,
15071538
ContextUsageKind kind) {
1508-
assert(kind == ContextUsageKind::ShaderRecordBufferEXT ||
1539+
assert(kind == ContextUsageKind::ShaderRecordBufferKHR ||
15091540
kind == ContextUsageKind::ShaderRecordBufferNV);
15101541

1511-
const auto typeName = kind == ContextUsageKind::ShaderRecordBufferEXT
1512-
? "type.ShaderRecordBufferEXT."
1542+
const auto typeName = kind == ContextUsageKind::ShaderRecordBufferKHR
1543+
? "type.ShaderRecordBufferKHR."
15131544
: "type.ShaderRecordBufferNV.";
15141545

15151546
const std::string structName = typeName + decl->getName().str();
@@ -1769,13 +1800,6 @@ void DeclResultIdMapper::createFieldCounterVars(
17691800
}
17701801
}
17711802

1772-
const SpirvType *
1773-
DeclResultIdMapper::getCTBufferPushConstantType(const DeclContext *decl) {
1774-
const auto found = ctBufferPCTypes.find(decl);
1775-
assert(found != ctBufferPCTypes.end());
1776-
return found->second;
1777-
}
1778-
17791803
std::vector<SpirvVariable *>
17801804
DeclResultIdMapper::collectStageVars(SpirvFunction *entryPoint) const {
17811805
std::vector<SpirvVariable *> vars;

tools/clang/lib/SPIRV/DeclResultIdMapper.h

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ class DeclResultIdMapper {
352352
PushConstant,
353353
Globals,
354354
ShaderRecordBufferNV,
355-
ShaderRecordBufferEXT
355+
ShaderRecordBufferKHR
356356
};
357357

358358
/// Raytracing specific functions
@@ -432,19 +432,6 @@ class DeclResultIdMapper {
432432
/// buffers. Returns nullptr if it does not.
433433
const CounterVarFields *getCounterVarFields(const DeclaratorDecl *decl);
434434

435-
/// \brief Returns the <type-id> for the given cbuffer, tbuffer,
436-
/// ConstantBuffer, TextureBuffer, or push constant block.
437-
///
438-
/// Note: we need this method because constant/texture buffers and push
439-
/// constant blocks are all represented as normal struct types upon which
440-
/// they are parameterized. That is different from structured buffers,
441-
/// for which we can tell they are not normal structs by investigating
442-
/// the name. But for constant/texture buffers and push constant blocks,
443-
/// we need to have the additional Block/BufferBlock decoration to keep
444-
/// type consistent. Normal translation path for structs via TypeTranslator
445-
/// won't attach Block/BufferBlock decoration.
446-
const SpirvType *getCTBufferPushConstantType(const DeclContext *decl);
447-
448435
/// \brief Returns all defined stage (builtin/input/ouput) variables for the
449436
/// entry point function entryPoint in this mapper.
450437
std::vector<SpirvVariable *>
@@ -825,10 +812,6 @@ class DeclResultIdMapper {
825812
/// until a Increment/DecrementCounter method is called on it.
826813
llvm::DenseMap<const DeclaratorDecl *, SpirvInstruction *> declRWSBuffers;
827814

828-
/// Mapping from cbuffer/tbuffer/ConstantBuffer/TextureBufer/push-constant
829-
/// to the SPIR-V type.
830-
llvm::DenseMap<const DeclContext *, const SpirvType *> ctBufferPCTypes;
831-
832815
/// The execution mode to use for rasterizer ordered views. Should be set to
833816
/// PixelInterlockOrderedEXT (default), SampleInterlockOrderedEXT, or
834817
/// ShadingRateInterlockOrderedEXT. This will be set based on which semantics

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1677,7 +1677,7 @@ void SpirvEmitter::doHLSLBufferDecl(const HLSLBufferDecl *bufferDecl) {
16771677
} else if (bufferDecl->hasAttr<VKShaderRecordEXTAttr>()) {
16781678
(void)declIdMapper.createShaderRecordBuffer(
16791679
bufferDecl,
1680-
DeclResultIdMapper::ContextUsageKind::ShaderRecordBufferEXT);
1680+
DeclResultIdMapper::ContextUsageKind::ShaderRecordBufferKHR);
16811681
} else {
16821682
(void)declIdMapper.createCTBuffer(bufferDecl);
16831683
}
@@ -1793,7 +1793,7 @@ void SpirvEmitter::doVarDecl(const VarDecl *decl) {
17931793

17941794
if (decl->hasAttr<VKShaderRecordEXTAttr>()) {
17951795
(void)declIdMapper.createShaderRecordBuffer(
1796-
decl, DeclResultIdMapper::ContextUsageKind::ShaderRecordBufferEXT);
1796+
decl, DeclResultIdMapper::ContextUsageKind::ShaderRecordBufferKHR);
17971797
return;
17981798
}
17991799

tools/clang/test/CodeGenSPIRV_Lit/vk.layout.shader-record-ext.std430.hlsl

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ struct T {
2525
row_major float3x2 f3[2];
2626
};
2727

28-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_S 0 Offset 0
29-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_S 1 Offset 16
30-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_S 2 Offset 32
31-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_S 3 Offset 224
32-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_S 4 Offset 256
33-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_S 4 MatrixStride 16
34-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_S 4 ColMajor
28+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 0 Offset 0
29+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 1 Offset 16
30+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 2 Offset 32
31+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 3 Offset 224
32+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 4 Offset 256
33+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 4 MatrixStride 16
34+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 4 ColMajor
3535

3636
struct S {
3737
float f1;
@@ -44,14 +44,14 @@ struct S {
4444
[[vk::shader_record_ext]]
4545
ConstantBuffer<S> cbuf;
4646

47-
// CHECK: OpDecorate %type_ShaderRecordBufferEXT_S Block
48-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_block 0 Offset 0
49-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_block 1 Offset 16
50-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_block 2 Offset 32
51-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_block 3 Offset 224
52-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_block 4 Offset 256
53-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_block 4 MatrixStride 16
54-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferEXT_block 4 ColMajor
47+
// CHECK: OpDecorate %type_ConstantBuffer_S Block
48+
// CHECK: OpMemberDecorate %type_ShaderRecordBufferKHR_block 0 Offset 0
49+
// CHECK: OpMemberDecorate %type_ShaderRecordBufferKHR_block 1 Offset 16
50+
// CHECK: OpMemberDecorate %type_ShaderRecordBufferKHR_block 2 Offset 32
51+
// CHECK: OpMemberDecorate %type_ShaderRecordBufferKHR_block 3 Offset 224
52+
// CHECK: OpMemberDecorate %type_ShaderRecordBufferKHR_block 4 Offset 256
53+
// CHECK: OpMemberDecorate %type_ShaderRecordBufferKHR_block 4 MatrixStride 16
54+
// CHECK: OpMemberDecorate %type_ShaderRecordBufferKHR_block 4 ColMajor
5555

5656

5757
[[vk::shader_record_ext]]
@@ -63,10 +63,13 @@ cbuffer block {
6363
row_major float2x3 f3;
6464
}
6565

66-
// CHECK: OpDecorate %type_ShaderRecordBufferEXT_block Block
66+
// CHECK: OpDecorate %type_ShaderRecordBufferKHR_block Block
6767
struct Payload { float p; };
6868
struct Attr { float a; };
6969

70+
// CHECK: %_ptr_ShaderRecordBufferNV_type_ConstantBuffer_S = OpTypePointer ShaderRecordBufferNV %type_ConstantBuffer_S
71+
// CHECK: %cbuf = OpVariable %_ptr_ShaderRecordBufferNV_type_ConstantBuffer_S ShaderRecordBufferNV
72+
7073
[shader("closesthit")]
7174
void chs1(inout Payload P, in Attr A) {
7275
P.p = cbuf.f1;

tools/clang/test/CodeGenSPIRV_Lit/vk.layout.shader-record-nv.std430.hlsl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ struct T {
2525
row_major float3x2 f3[2];
2626
};
2727

28-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferNV_S 0 Offset 0
29-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferNV_S 1 Offset 16
30-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferNV_S 2 Offset 32
31-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferNV_S 3 Offset 224
32-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferNV_S 4 Offset 256
33-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferNV_S 4 MatrixStride 16
34-
// CHECK: OpMemberDecorate %type_ShaderRecordBufferNV_S 4 ColMajor
28+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 0 Offset 0
29+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 1 Offset 16
30+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 2 Offset 32
31+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 3 Offset 224
32+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 4 Offset 256
33+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 4 MatrixStride 16
34+
// CHECK: OpMemberDecorate %type_ConstantBuffer_S 4 ColMajor
3535

3636
struct S {
3737
float f1;
@@ -44,7 +44,7 @@ struct S {
4444
[[vk::shader_record_nv]]
4545
ConstantBuffer<S> cbuf;
4646

47-
// CHECK: OpDecorate %type_ShaderRecordBufferNV_S Block
47+
// CHECK: OpDecorate %type_ConstantBuffer_S Block
4848
// CHECK: OpMemberDecorate %type_ShaderRecordBufferNV_block 0 Offset 0
4949
// CHECK: OpMemberDecorate %type_ShaderRecordBufferNV_block 1 Offset 16
5050
// CHECK: OpMemberDecorate %type_ShaderRecordBufferNV_block 2 Offset 32
@@ -67,6 +67,9 @@ cbuffer block {
6767
struct Payload { float p; };
6868
struct Attr { float a; };
6969

70+
// CHECK: %_ptr_ShaderRecordBufferNV_type_ConstantBuffer_S = OpTypePointer ShaderRecordBufferNV %type_ConstantBuffer_S
71+
// CHECK: %cbuf = OpVariable %_ptr_ShaderRecordBufferNV_type_ConstantBuffer_S ShaderRecordBufferNV
72+
7073
[shader("closesthit")]
7174
void chs1(inout Payload P, in Attr A) {
7275
P.p = cbuf.f1;
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: %dxc -T cs_6_0 -E main -fcgl %s -spirv | FileCheck %s
2+
3+
struct Foo
4+
{
5+
float m_x;
6+
};
7+
8+
// CHECK: %g_pc = OpVariable %_ptr_PushConstant_type_ConstantBuffer_Foo PushConstant
9+
[[vk::push_constant]] ConstantBuffer<Foo> g_pc;
10+
RWStructuredBuffer<float> g_buff;
11+
12+
float mul1(Foo m, float4 v)
13+
{
14+
return m.m_x + v.x;
15+
}
16+
17+
[numthreads(1, 1, 1)] void main()
18+
{
19+
// CHECK: OpLoad %type_ConstantBuffer_Foo %g_pc
20+
g_buff[0] = mul1(g_pc, float4(1, 0, 0, 1));
21+
}

tools/clang/test/CodeGenSPIRV_Lit/vk.push-constant.constantbuffer.hlsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ struct StructA
66
float3 two;
77
};
88

9-
// CHECK: %type_PushConstant_StructA = OpTypeStruct %v3float %v3float
10-
// CHECK: %PushConstants = OpVariable %_ptr_PushConstant_type_PushConstant_StructA PushConstant
9+
// CHECK: %type_ConstantBuffer_StructA = OpTypeStruct %v3float %v3float
10+
// CHECK: %PushConstants = OpVariable %_ptr_PushConstant_type_ConstantBuffer_StructA PushConstant
1111
[[vk::push_constant]] ConstantBuffer<StructA> PushConstants;
1212

1313
void main()
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %dxc -T lib_6_7 -fspv-target-env=vulkan1.1spirv1.4 -E main -fcgl %s -spirv | FileCheck %s
2+
3+
// CHECK: OpCapability RayTracingKHR
4+
// CHECK: OpExtension "SPV_KHR_ray_query"
5+
// CHECK: OpExtension "SPV_KHR_ray_tracing"
6+
7+
struct Foo
8+
{
9+
float m_x;
10+
};
11+
12+
// CHECK: %g_pc = OpVariable %_ptr_ShaderRecordBufferNV_type_ConstantBuffer_Foo ShaderRecordBufferNV
13+
[[vk::shader_record_ext]] ConstantBuffer<Foo> g_pc;
14+
RWStructuredBuffer<float> g_buff;
15+
16+
float mul1(Foo m, float4 v)
17+
{
18+
return m.m_x + v.x;
19+
}
20+
21+
[shader("raygeneration")] void main()
22+
{
23+
// CHECK: OpLoad %type_ConstantBuffer_Foo %g_pc
24+
g_buff[0] = mul1(g_pc, float4(1, 0, 0, 1));
25+
}

0 commit comments

Comments
 (0)