Skip to content

Commit 9dacbe3

Browse files
[SPIR-V] Implement vk::ext_builtin_input and vk::ext_builtin_output (microsoft#6027)
I definitely think it would look better if we allowed these attributes on variables, ie microsoft/hlsl-specs#76. I haven't fully investigated how involved it would be to implement, but my intuition is that it wouldn't take that much more work. Fixes microsoft#4217.
1 parent a743e97 commit 9dacbe3

13 files changed

+256
-29
lines changed

tools/clang/include/clang/Basic/Attr.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,6 +1267,22 @@ def VKBuiltIn : InheritableAttr {
12671267
let Documentation = [Undocumented];
12681268
}
12691269

1270+
def VKExtBuiltinInput : InheritableAttr {
1271+
let Spellings = [CXX11<"vk", "ext_builtin_input">];
1272+
let Subjects = SubjectList<[Var], ErrorDiag>;
1273+
let Args = [IntArgument<"BuiltInID">];
1274+
let LangOpts = [SPIRV];
1275+
let Documentation = [Undocumented];
1276+
}
1277+
1278+
def VKExtBuiltinOutput : InheritableAttr {
1279+
let Spellings = [CXX11<"vk", "ext_builtin_output">];
1280+
let Subjects = SubjectList<[Var], ErrorDiag>;
1281+
let Args = [IntArgument<"BuiltInID">];
1282+
let LangOpts = [SPIRV];
1283+
let Documentation = [Undocumented];
1284+
}
1285+
12701286
def VKLocation : InheritableAttr {
12711287
let Spellings = [CXX11<"vk", "location">];
12721288
let Subjects = SubjectList<[Function, ParmVar, Field], ErrorDiag>;

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,15 @@ SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
968968
/* spvArgs */ {}, /* isInst */ false,
969969
loc);
970970
}
971+
972+
if (auto *builtinAttr = decl->getAttr<VKExtBuiltinInputAttr>()) {
973+
return getBuiltinVar(spv::BuiltIn(builtinAttr->getBuiltInID()),
974+
decl->getType(), spv::StorageClass::Input, loc);
975+
} else if (auto *builtinAttr = decl->getAttr<VKExtBuiltinOutputAttr>()) {
976+
return getBuiltinVar(spv::BuiltIn(builtinAttr->getBuiltInID()),
977+
decl->getType(), spv::StorageClass::Output, loc);
978+
}
979+
971980
if (hlsl::IsHLSLDynamicResourceType(decl->getType()) ||
972981
hlsl::IsHLSLDynamicSamplerType(decl->getType())) {
973982
emitError("HLSL object %0 not yet supported with -spirv",
@@ -3808,23 +3817,60 @@ void DeclResultIdMapper::decorateInterpolationMode(
38083817

38093818
SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
38103819
QualType type,
3820+
spv::StorageClass sc,
38113821
SourceLocation loc) {
38123822
// Guarantee uniqueness
38133823
uint32_t spvBuiltinId = static_cast<uint32_t>(builtIn);
38143824
const auto builtInVar = builtinToVarMap.find(spvBuiltinId);
38153825
if (builtInVar != builtinToVarMap.end()) {
38163826
return builtInVar->second;
38173827
}
3818-
bool mayNeedFlatDecoration = false;
3828+
switch (builtIn) {
3829+
case spv::BuiltIn::HelperInvocation:
3830+
case spv::BuiltIn::SubgroupSize:
3831+
case spv::BuiltIn::SubgroupLocalInvocationId:
3832+
needsLegalization = true;
3833+
break;
3834+
}
3835+
3836+
// Create a dummy StageVar for this builtin variable
3837+
auto var = spvBuilder.addStageBuiltinVar(type, sc, builtIn,
3838+
/*isPrecise*/ false, loc);
3839+
3840+
if (spvContext.isPS() && sc == spv::StorageClass::Input) {
3841+
if (isUintOrVecMatOfUintType(type) || isSintOrVecMatOfSintType(type) ||
3842+
isBoolOrVecMatOfBoolType(type)) {
3843+
spvBuilder.decorateFlat(var, loc);
3844+
}
3845+
}
3846+
3847+
const hlsl::SigPoint *sigPoint =
3848+
hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
3849+
hlsl::DxilParamInputQual::In, spvContext.getCurrentShaderModelKind(),
3850+
/*isPatchConstant=*/false));
3851+
3852+
StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
3853+
/*locAndComponentCount=*/{0, 0, false});
3854+
3855+
stageVar.setIsSpirvBuiltin();
3856+
stageVar.setSpirvInstr(var);
3857+
stageVars.push_back(stageVar);
3858+
3859+
// Store in map for re-use
3860+
builtinToVarMap[spvBuiltinId] = var;
3861+
return var;
3862+
}
3863+
3864+
SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
3865+
QualType type,
3866+
SourceLocation loc) {
38193867
spv::StorageClass sc = spv::StorageClass::Max;
3868+
38203869
// Valid builtins supported
38213870
switch (builtIn) {
38223871
case spv::BuiltIn::HelperInvocation:
38233872
case spv::BuiltIn::SubgroupSize:
38243873
case spv::BuiltIn::SubgroupLocalInvocationId:
3825-
needsLegalization = true;
3826-
mayNeedFlatDecoration = true;
3827-
LLVM_FALLTHROUGH;
38283874
case spv::BuiltIn::HitTNV:
38293875
case spv::BuiltIn::RayTmaxNV:
38303876
case spv::BuiltIn::RayTminNV:
@@ -3857,32 +3903,11 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
38573903
sc = spv::StorageClass::Output;
38583904
break;
38593905
default:
3860-
assert(false && "unsupported SPIR-V builtin");
3861-
return nullptr;
3862-
}
3863-
3864-
// Create a dummy StageVar for this builtin variable
3865-
auto var = spvBuilder.addStageBuiltinVar(type, sc, builtIn,
3866-
/*isPrecise*/ false, loc);
3867-
if (mayNeedFlatDecoration && spvContext.isPS()) {
3868-
spvBuilder.decorateFlat(var, loc);
3906+
assert(false && "cannot infer storage class for SPIR-V builtin");
3907+
break;
38693908
}
38703909

3871-
const hlsl::SigPoint *sigPoint =
3872-
hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
3873-
hlsl::DxilParamInputQual::In, spvContext.getCurrentShaderModelKind(),
3874-
/*isPatchConstant=*/false));
3875-
3876-
StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
3877-
/*locAndComponentCount=*/{0, 0, false});
3878-
3879-
stageVar.setIsSpirvBuiltin();
3880-
stageVar.setSpirvInstr(var);
3881-
stageVars.push_back(stageVar);
3882-
3883-
// Store in map for re-use
3884-
builtinToVarMap[spvBuiltinId] = var;
3885-
return var;
3910+
return getBuiltinVar(builtIn, type, sc, loc);
38863911
}
38873912

38883913
SpirvVariable *DeclResultIdMapper::createSpirvStageVar(

tools/clang/lib/SPIRV/DeclResultIdMapper.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,13 @@ class DeclResultIdMapper {
195195
FeatureManager &features,
196196
const SpirvCodeGenOptions &spirvOptions);
197197

198-
/// \brief Returns the SPIR-V builtin variable.
198+
/// \brief Returns the SPIR-V builtin variable. Uses sc as default storage
199+
/// class.
200+
SpirvVariable *getBuiltinVar(spv::BuiltIn builtIn, QualType type,
201+
spv::StorageClass sc, SourceLocation);
202+
203+
/// \brief Returns the SPIR-V builtin variable. Tries to infer storage class
204+
/// from the builtin.
199205
SpirvVariable *getBuiltinVar(spv::BuiltIn builtIn, QualType type,
200206
SourceLocation);
201207

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1645,6 +1645,39 @@ bool SpirvEmitter::validateVKAttributes(const NamedDecl *decl) {
16451645
}
16461646
}
16471647

1648+
// a VarDecl should have only one of vk::ext_builtin_input or
1649+
// vk::ext_builtin_output
1650+
if (decl->hasAttr<VKExtBuiltinInputAttr>() &&
1651+
decl->hasAttr<VKExtBuiltinOutputAttr>()) {
1652+
emitError("vk::ext_builtin_input cannot be used together with "
1653+
"vk::ext_builtin_output",
1654+
decl->getAttr<VKExtBuiltinOutputAttr>()->getLocation());
1655+
success = false;
1656+
}
1657+
1658+
// vk::ext_builtin_input and vk::ext_builtin_output must only be used for a
1659+
// static variable. We only allow them to be attached to variables, so it
1660+
// should be fine to cast here.
1661+
if ((decl->hasAttr<VKExtBuiltinInputAttr>() ||
1662+
decl->hasAttr<VKExtBuiltinOutputAttr>()) &&
1663+
cast<VarDecl>(decl)->getStorageClass() != StorageClass::SC_Static) {
1664+
emitError("vk::ext_builtin_input and vk::ext_builtin_output can only be "
1665+
"applied to a static variable",
1666+
decl->getLocation());
1667+
success = false;
1668+
}
1669+
1670+
// vk::ext_builtin_input and vk::ext_builtin_output must only be used for a
1671+
// static variable. We only allow them to be attached to variables, so it
1672+
// should be fine to cast here.
1673+
if (decl->hasAttr<VKExtBuiltinInputAttr>() &&
1674+
!cast<VarDecl>(decl)->getType().isConstQualified()) {
1675+
emitError("vk::ext_builtin_input can only be applied to a const-qualified "
1676+
"variable",
1677+
decl->getLocation());
1678+
success = false;
1679+
}
1680+
16481681
return success;
16491682
}
16501683

@@ -1799,6 +1832,38 @@ void SpirvEmitter::doVarDecl(const VarDecl *decl) {
17991832
return;
18001833
}
18011834

1835+
// Handle vk::ext_builtin_input and vk::ext_builtin_input by using
1836+
// getBuiltinVar to create the builtin and validate the storage class
1837+
if (decl->hasAttr<VKExtBuiltinInputAttr>()) {
1838+
auto *builtinAttr = decl->getAttr<VKExtBuiltinInputAttr>();
1839+
int builtinId = builtinAttr->getBuiltInID();
1840+
SpirvVariable *builtinVar =
1841+
declIdMapper.getBuiltinVar(spv::BuiltIn(builtinId), decl->getType(),
1842+
spv::StorageClass::Input, loc);
1843+
if (builtinVar->getStorageClass() != spv::StorageClass::Input) {
1844+
emitError("cannot redefine builtin %0 as an input",
1845+
builtinAttr->getLocation())
1846+
<< builtinId;
1847+
emitWarning("previous definition is here",
1848+
builtinVar->getSourceLocation());
1849+
}
1850+
return;
1851+
} else if (decl->hasAttr<VKExtBuiltinOutputAttr>()) {
1852+
auto *builtinAttr = decl->getAttr<VKExtBuiltinOutputAttr>();
1853+
int builtinId = builtinAttr->getBuiltInID();
1854+
SpirvVariable *builtinVar =
1855+
declIdMapper.getBuiltinVar(spv::BuiltIn(builtinId), decl->getType(),
1856+
spv::StorageClass::Output, loc);
1857+
if (builtinVar->getStorageClass() != spv::StorageClass::Output) {
1858+
emitError("cannot redefine builtin %0 as an output",
1859+
builtinAttr->getLocation())
1860+
<< builtinId;
1861+
emitWarning("previous definition is here",
1862+
builtinVar->getSourceLocation());
1863+
}
1864+
return;
1865+
}
1866+
18021867
// We can have VarDecls inside cbuffer/tbuffer. For those VarDecls, we need
18031868
// to emit their cbuffer/tbuffer as a whole and access each individual one
18041869
// using access chains.

tools/clang/lib/Sema/SemaHLSL.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13585,6 +13585,16 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
1358513585
"DrawIndex,DeviceIndex,ViewportMaskNV"),
1358613586
A.getAttributeSpellingListIndex());
1358713587
break;
13588+
case AttributeList::AT_VKExtBuiltinInput:
13589+
declAttr = ::new (S.Context) VKExtBuiltinInputAttr(
13590+
A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
13591+
A.getAttributeSpellingListIndex());
13592+
break;
13593+
case AttributeList::AT_VKExtBuiltinOutput:
13594+
declAttr = ::new (S.Context) VKExtBuiltinOutputAttr(
13595+
A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
13596+
A.getAttributeSpellingListIndex());
13597+
break;
1358813598
case AttributeList::AT_VKLocation:
1358913599
declAttr = ::new (S.Context)
1359013600
VKLocationAttr(A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: not %dxc -T ps_6_0 -E main -fcgl %s -spirv 2>&1 | FileCheck %s
2+
3+
// CHECK: error: vk::ext_builtin_input cannot be used together with vk::ext_builtin_output
4+
[[vk::ext_builtin_input(/* NumWorkgroups */ 24)]]
5+
[[vk::ext_builtin_output(/* NumWorkgroups */ 24)]]
6+
static uint3 invalid;
7+
8+
void main() {
9+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: %dxc -T ps_6_0 -E main -fcgl %s -spirv | FileCheck %s
2+
3+
// CHECK: OpEntryPoint Fragment %main "main" %gl_SampleID
4+
// CHECK: OpDecorate %gl_SampleID BuiltIn SampleId
5+
// CHECK: OpDecorate %gl_SampleID Flat
6+
7+
// CHECK: %gl_SampleID = OpVariable %_ptr_Input_int Input
8+
9+
[[vk::ext_builtin_input(/* SampleID */ 18)]]
10+
static const int gl_SampleID;
11+
12+
void main() {
13+
// CHECK: {{%[0-9]+}} = OpLoad %int %gl_SampleID
14+
int sID = gl_SampleID;
15+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: %dxc -T cs_6_0 -E main -fcgl %s -spirv | FileCheck %s
2+
3+
// CHECK: OpEntryPoint GLCompute %main "main" %gl_NumWorkGroups
4+
// CHECK: OpDecorate %gl_NumWorkGroups BuiltIn NumWorkgroups
5+
6+
// CHECK: %gl_NumWorkGroups = OpVariable %_ptr_Input_v3uint Input
7+
8+
[[vk::ext_builtin_input(/* NumWorkgroups */ 24)]]
9+
static const uint3 gl_NumWorkGroups;
10+
11+
uint square_x(uint3 v) {
12+
return v.x * v.x;
13+
}
14+
15+
[numthreads(32,1,1)]
16+
void main() {
17+
// CHECK: {{%[0-9]+}} = OpLoad %v3uint %gl_NumWorkGroups
18+
uint3 numWorkgroups = gl_NumWorkGroups;
19+
// CHECK: [[nwg:%[0-9]+]] = OpLoad %v3uint %gl_NumWorkGroups
20+
// CHECK: OpStore %param_var_v [[nwg]]
21+
// CHECK: OpFunctionCall %uint %square_x %param_var_v
22+
square_x(gl_NumWorkGroups);
23+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: not %dxc -T ps_6_0 -E main -fcgl %s -spirv 2>&1 | FileCheck %s
2+
3+
// CHECK: error: vk::ext_builtin_input can only be applied to a const-qualified variable
4+
[[vk::ext_builtin_input(/* NumWorkgroups */ 24)]]
5+
static uint3 invalid;
6+
7+
void main() {
8+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: not %dxc -T ps_6_0 -E main -fcgl %s -spirv 2>&1 | FileCheck %s
2+
3+
// CHECK: error: vk::ext_builtin_input and vk::ext_builtin_output can only be applied to a static variable
4+
[[vk::ext_builtin_input(/* NumWorkgroups */ 24)]]
5+
uint3 invalid;
6+
7+
void main() {
8+
}

0 commit comments

Comments
 (0)