Skip to content

Commit 24b9a4e

Browse files
committed
[HLSL][SPIR-V] Implement vk::location for inputs
This commit adds the support for vk::location attribute, focusing on input semantics.
1 parent 9c4c26f commit 24b9a4e

15 files changed

+227
-7
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5172,6 +5172,14 @@ def HLSLVkConstantId : InheritableAttr {
51725172
let Documentation = [VkConstantIdDocs];
51735173
}
51745174

5175+
def HLSLVkLocation : HLSLAnnotationAttr {
5176+
let Spellings = [CXX11<"vk", "location">];
5177+
let Args = [IntArgument<"Location">];
5178+
let Subjects = SubjectList<[ParmVar, Field, Function], ErrorDiag>;
5179+
let LangOpts = [HLSL];
5180+
let Documentation = [HLSLVkLocationDocs];
5181+
}
5182+
51755183
def RandomizeLayout : InheritableAttr {
51765184
let Spellings = [GCC<"randomize_layout">];
51775185
let Subjects = SubjectList<[Record]>;

clang/include/clang/Basic/AttrDocs.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8981,6 +8981,18 @@ The descriptor set is optional and defaults to 0 if not provided.
89818981
}];
89828982
}
89838983

8984+
def HLSLVkLocationDocs : Documentation {
8985+
let Category = DocCatVariable;
8986+
let Content = [{
8987+
Attribute used for specifying the location number for the stage input/output
8988+
variables. Allowed on function parameters, function returns, and struct
8989+
fields. This parameter has no effect when used outside of an entrypoint
8990+
parameter/parameter field/return value.
8991+
8992+
This attribute maps to the 'Location' SPIR-V decoration.
8993+
}];
8994+
}
8995+
89848996
def WebAssemblyFuncrefDocs : Documentation {
89858997
let Category = DocCatType;
89868998
let Content = [{

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13234,6 +13234,9 @@ def err_hlsl_semantic_index_overlap : Error<"semantic index overlap %0">;
1323413234
def err_hlsl_semantic_unsupported_iotype_for_stage
1323513235
: Error<"semantic %0 is unsupported in %2 shaders as %1, requires one of "
1323613236
"the following: %3">;
13237+
def err_hlsl_semantic_partial_explicit_indexing
13238+
: Error<"partial explicit stage input location assignment via "
13239+
"vk::location(X) unsupported">;
1323713240

1323813241
def warn_hlsl_user_defined_type_missing_member: Warning<"binding type '%select{t|u|b|s|c}0' only applies to types containing %select{SRV resources|UAV resources|constant buffer resources|sampler state|numeric types}0">, InGroup<LegacyConstantRegisterBinding>;
1323913242
def err_hlsl_binding_type_mismatch: Error<"binding type '%select{t|u|b|s|c}0' only applies to %select{SRV resources|UAV resources|constant buffer resources|sampler state|numeric variables in the global scope}0">;

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ class SemaHLSL : public SemaBase {
168168
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
169169
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
170170
void handleVkBindingAttr(Decl *D, const ParsedAttr &AL);
171+
void handleVkLocationAttr(Decl *D, const ParsedAttr &AL);
171172
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
172173
void handleShaderAttr(Decl *D, const ParsedAttr &AL);
173174
void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL);
@@ -240,6 +241,8 @@ class SemaHLSL : public SemaBase {
240241
HLSLParsedSemanticAttr *Semantic;
241242
std::optional<uint32_t> Index;
242243
};
244+
std::optional<bool> InputUsesExplicitVkLocations = std::nullopt;
245+
std::optional<bool> OutputUsesExplicitVkLocations = std::nullopt;
243246

244247
enum IOType {
245248
In = 0b01,

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -582,20 +582,22 @@ static llvm::Value *createSPIRVLocationLoad(IRBuilder<> &B, llvm::Module &M,
582582
return B.CreateLoad(Ty, GV);
583583
}
584584

585-
llvm::Value *
586-
CGHLSLRuntime::emitSPIRVUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
587-
HLSLAppliedSemanticAttr *Semantic,
588-
std::optional<unsigned> Index) {
585+
llvm::Value *CGHLSLRuntime::emitSPIRVUserSemanticLoad(
586+
llvm::IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
587+
HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
589588
Twine BaseName = Twine(Semantic->getAttrName()->getName());
590589
Twine VariableName = BaseName.concat(Twine(Index.value_or(0)));
591590

592591
unsigned Location = SPIRVLastAssignedInputSemanticLocation;
592+
if (auto *L = Decl->getAttr<HLSLVkLocationAttr>())
593+
Location = L->getLocation();
593594

594595
// DXC completely ignores the semantic/index pair. Location are assigned from
595596
// the first semantic to the last.
596597
llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Type);
597598
unsigned ElementCount = AT ? AT->getNumElements() : 1;
598599
SPIRVLastAssignedInputSemanticLocation += ElementCount;
600+
599601
return createSPIRVLocationLoad(B, CGM.getModule(), Type, Location,
600602
VariableName.str());
601603
}
@@ -616,10 +618,14 @@ static void createSPIRVLocationStore(IRBuilder<> &B, llvm::Module &M,
616618

617619
void CGHLSLRuntime::emitSPIRVUserSemanticStore(
618620
llvm::IRBuilder<> &B, llvm::Value *Source,
619-
HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
621+
const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic,
622+
std::optional<unsigned> Index) {
620623
Twine BaseName = Twine(Semantic->getAttrName()->getName());
621624
Twine VariableName = BaseName.concat(Twine(Index.value_or(0)));
625+
622626
unsigned Location = SPIRVLastAssignedOutputSemanticLocation;
627+
if (auto *L = Decl->getAttr<HLSLVkLocationAttr>())
628+
Location = L->getLocation();
623629

624630
// DXC completely ignores the semantic/index pair. Location are assigned from
625631
// the first semantic to the last.
@@ -671,7 +677,7 @@ llvm::Value *CGHLSLRuntime::emitUserSemanticLoad(
671677
IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
672678
HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
673679
if (CGM.getTarget().getTriple().isSPIRV())
674-
return emitSPIRVUserSemanticLoad(B, Type, Semantic, Index);
680+
return emitSPIRVUserSemanticLoad(B, Type, Decl, Semantic, Index);
675681

676682
if (CGM.getTarget().getTriple().isDXIL())
677683
return emitDXILUserSemanticLoad(B, Type, Semantic, Index);
@@ -684,7 +690,7 @@ void CGHLSLRuntime::emitUserSemanticStore(IRBuilder<> &B, llvm::Value *Source,
684690
HLSLAppliedSemanticAttr *Semantic,
685691
std::optional<unsigned> Index) {
686692
if (CGM.getTarget().getTriple().isSPIRV())
687-
return emitSPIRVUserSemanticStore(B, Source, Semantic, Index);
693+
return emitSPIRVUserSemanticStore(B, Source, Decl, Semantic, Index);
688694

689695
if (CGM.getTarget().getTriple().isDXIL())
690696
return emitDXILUserSemanticStore(B, Source, Semantic, Index);

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ class CGHLSLRuntime {
278278
HLSLResourceBindingAttr *RBA);
279279

280280
llvm::Value *emitSPIRVUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
281+
const clang::DeclaratorDecl *Decl,
281282
HLSLAppliedSemanticAttr *Semantic,
282283
std::optional<unsigned> Index);
283284
llvm::Value *emitDXILUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
@@ -289,6 +290,7 @@ class CGHLSLRuntime {
289290
std::optional<unsigned> Index);
290291

291292
void emitSPIRVUserSemanticStore(llvm::IRBuilder<> &B, llvm::Value *Source,
293+
const clang::DeclaratorDecl *Decl,
292294
HLSLAppliedSemanticAttr *Semantic,
293295
std::optional<unsigned> Index);
294296
void emitDXILUserSemanticStore(llvm::IRBuilder<> &B, llvm::Value *Source,

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7703,6 +7703,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
77037703
case ParsedAttr::AT_HLSLUnparsedSemantic:
77047704
S.HLSL().handleSemanticAttr(D, AL);
77057705
break;
7706+
case ParsedAttr::AT_HLSLVkLocation:
7707+
S.HLSL().handleVkLocationAttr(D, AL);
7708+
break;
77067709

77077710
case ParsedAttr::AT_AbiTag:
77087711
handleAbiTagAttr(S, D, AL);

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,22 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
771771
}
772772
}
773773

774+
static bool isPipelineBuiltin(const ASTContext &AstContext, FunctionDecl *FD,
775+
HLSLAppliedSemanticAttr *Semantic) {
776+
if (AstContext.getTargetInfo().getTriple().getOS() != llvm::Triple::Vulkan)
777+
return false;
778+
779+
const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
780+
assert(ShaderAttr && "Entry point has no shader attribute");
781+
llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
782+
auto SemanticName = Semantic->getSemanticName().upper();
783+
784+
if (ST == llvm::Triple::Pixel && SemanticName == "SV_POSITION")
785+
return true;
786+
787+
return false;
788+
}
789+
774790
bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
775791
DeclaratorDecl *OutputDecl,
776792
DeclaratorDecl *D,
@@ -800,6 +816,22 @@ bool SemaHLSL::determineActiveSemanticOnScalar(FunctionDecl *FD,
800816

801817
unsigned Location = ActiveSemantic.Index.value_or(0);
802818

819+
if (!isPipelineBuiltin(getASTContext(), FD, A)) {
820+
bool HasVkLocation = false;
821+
if (auto *A = D->getAttr<HLSLVkLocationAttr>()) {
822+
HasVkLocation = true;
823+
Location = A->getLocation();
824+
}
825+
826+
auto &UsesExplicitVkLocations =
827+
IsInput ? InputUsesExplicitVkLocations : OutputUsesExplicitVkLocations;
828+
if (UsesExplicitVkLocations.value_or(HasVkLocation) != HasVkLocation) {
829+
Diag(D->getLocation(), diag::err_hlsl_semantic_partial_explicit_indexing);
830+
return false;
831+
}
832+
UsesExplicitVkLocations = HasVkLocation;
833+
}
834+
803835
const ConstantArrayType *AT = dyn_cast<ConstantArrayType>(D->getType());
804836
unsigned ElementCount = AT ? AT->getZExtSize() : 1;
805837
ActiveSemantic.Index = Location + ElementCount;
@@ -1757,6 +1789,15 @@ void SemaHLSL::handleVkBindingAttr(Decl *D, const ParsedAttr &AL) {
17571789
HLSLVkBindingAttr(getASTContext(), AL, Binding, Set));
17581790
}
17591791

1792+
void SemaHLSL::handleVkLocationAttr(Decl *D, const ParsedAttr &AL) {
1793+
uint32_t Location;
1794+
if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Location))
1795+
return;
1796+
1797+
D->addAttr(::new (getASTContext())
1798+
HLSLVkLocationAttr(getASTContext(), AL, Location));
1799+
}
1800+
17601801
bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
17611802
const auto *VT = T->getAs<VectorType>();
17621803

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: %clang_cc1 -triple spirv-pc-vulkan1.3-pixel -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
2+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-pixel -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
3+
4+
// CHECK-SPIRV: @SV_Position = external hidden thread_local addrspace(7) externally_initialized constant <4 x float>, !spirv.Decorations ![[#MD_0:]]
5+
// CHECK-SPIRV: @SV_Target0 = external hidden thread_local addrspace(8) global <4 x float>, !spirv.Decorations ![[#MD_2:]]
6+
7+
struct Output {
8+
[[vk::location(2)]] float4 field : SV_Target;
9+
};
10+
11+
// CHECK: define void @main() {{.*}} {
12+
Output main(float4 p : SV_Position) {
13+
// CHECK: %[[#OUT:]] = alloca %struct.Output, align 16
14+
15+
// CHECK-SPIRV: %[[#IN:]] = load <4 x float>, ptr addrspace(7) @SV_Position, align 16
16+
// CHECK-SPIRV: call spir_func void @_Z4mainDv4_f(ptr %[[#OUT]], <4 x float> %[[#IN]])
17+
18+
// CHECK-DXIL: call void @_Z4mainDv4_f(ptr %[[#OUT]], <4 x float> %SV_Position0)
19+
20+
// CHECK: %[[#TMP:]] = load %struct.Output, ptr %[[#OUT]], align 16
21+
// CHECK: %[[#FIELD:]] = extractvalue %struct.Output %[[#TMP]], 0
22+
23+
// CHECK-SPIRV: store <4 x float> %[[#FIELD]], ptr addrspace(8) @SV_Target0, align 16
24+
// CHECK-DXIL: call void @llvm.dx.store.output.v4f32(i32 4, i32 0, i32 0, i8 0, i32 poison, <4 x float> %[[#FIELD]])
25+
Output o;
26+
o.field = p;
27+
return o;
28+
}
29+
30+
// CHECK-SPIRV-DAG: ![[#MD_0]] = !{![[#MD_1:]]}
31+
// CHECK-SPIRV-DAG: ![[#MD_1]] = !{i32 11, i32 15}
32+
// | `-> BuiltIn 'FragCoord'
33+
// `-> SPIR-V decoration 'BuiltIn'
34+
// CHECK-SPIRV-DAG: ![[#MD_2]] = !{![[#MD_3:]]}
35+
// CHECK-SPIRV-DAG: ![[#MD_3]] = !{i32 30, i32 2}
36+
// | `-> Location index
37+
// `-> SPIR-V decoration 'Location'
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: %clang_cc1 -triple spirv-pc-vulkan1.3-pixel -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-SPIRV
2+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-pixel -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-DXIL
3+
4+
// CHECK-SPIRV: @SV_Target0 = external hidden thread_local addrspace(8) global <4 x float>, !spirv.Decorations ![[#MD_2:]]
5+
6+
// CHECK: define void @main() {{.*}} {
7+
[[vk::location(2)]] float4 main(float4 p : SV_Position) : SV_Target {
8+
// CHECK-SPIRV: %[[#R:]] = call spir_func <4 x float> @_Z4mainDv4_f(<4 x float> %[[#]])
9+
// CHECK-SPIRV: store <4 x float> %[[#R]], ptr addrspace(8) @SV_Target0, align 16
10+
11+
// CHECK-DXIL: %[[#TMP:]] = call <4 x float> @_Z4mainDv4_f(<4 x float> %SV_Position0)
12+
// CHECK-DXIL: call void @llvm.dx.store.output.v4f32(i32 4, i32 0, i32 0, i8 0, i32 poison, <4 x float> %[[#TMP]])
13+
return p;
14+
}
15+
16+
// CHECK-SPIRV-DAG: ![[#MD_2]] = !{![[#MD_3:]]}
17+
// CHECK-SPIRV-DAG: ![[#MD_3]] = !{i32 30, i32 2}
18+
// | `-> Location index
19+
// `-> SPIR-V decoration 'Location'

0 commit comments

Comments
 (0)