Skip to content

Commit 9f51858

Browse files
author
joaosaffran
committed
adding metadata support for static samplers
1 parent e8066df commit 9f51858

File tree

4 files changed

+157
-17
lines changed

4 files changed

+157
-17
lines changed

llvm/lib/MC/DXContainerRootSignature.cpp

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "llvm/MC/DXContainerRootSignature.h"
1010
#include "llvm/ADT/SmallString.h"
1111
#include "llvm/Support/EndianStream.h"
12+
#include <cstdint>
1213

1314
using namespace llvm;
1415
using namespace llvm::mcdxbc;
@@ -71,12 +72,16 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
7172
BOS.reserveExtraSpace(getSize());
7273

7374
const uint32_t NumParameters = ParametersContainer.size();
74-
75+
const uint32_t NumSamplers = StaticSamplers.size();
7576
support::endian::write(BOS, Version, llvm::endianness::little);
7677
support::endian::write(BOS, NumParameters, llvm::endianness::little);
7778
support::endian::write(BOS, RootParameterOffset, llvm::endianness::little);
78-
support::endian::write(BOS, NumStaticSamplers, llvm::endianness::little);
79-
support::endian::write(BOS, StaticSamplersOffset, llvm::endianness::little);
79+
support::endian::write(BOS, NumSamplers, llvm::endianness::little);
80+
uint32_t SSO = StaticSamplersOffset;
81+
if (NumSamplers > 0)
82+
SSO = writePlaceholder(BOS);
83+
else
84+
support::endian::write(BOS, SSO, llvm::endianness::little);
8085
support::endian::write(BOS, Flags, llvm::endianness::little);
8186

8287
SmallVector<uint32_t> ParamsOffsets;
@@ -142,20 +147,23 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
142147
}
143148
}
144149
}
145-
for (const auto &S : StaticSamplers) {
146-
support::endian::write(BOS, S.Filter, llvm::endianness::little);
147-
support::endian::write(BOS, S.AddressU, llvm::endianness::little);
148-
support::endian::write(BOS, S.AddressV, llvm::endianness::little);
149-
support::endian::write(BOS, S.AddressW, llvm::endianness::little);
150-
support::endian::write(BOS, S.MipLODBias, llvm::endianness::little);
151-
support::endian::write(BOS, S.MaxAnisotropy, llvm::endianness::little);
152-
support::endian::write(BOS, S.ComparisonFunc, llvm::endianness::little);
153-
support::endian::write(BOS, S.BorderColor, llvm::endianness::little);
154-
support::endian::write(BOS, S.MinLOD, llvm::endianness::little);
155-
support::endian::write(BOS, S.MaxLOD, llvm::endianness::little);
156-
support::endian::write(BOS, S.ShaderRegister, llvm::endianness::little);
157-
support::endian::write(BOS, S.RegisterSpace, llvm::endianness::little);
158-
support::endian::write(BOS, S.ShaderVisibility, llvm::endianness::little);
150+
if (NumSamplers > 0) {
151+
rewriteOffsetToCurrentByte(BOS, SSO);
152+
for (const auto &S : StaticSamplers) {
153+
support::endian::write(BOS, S.Filter, llvm::endianness::little);
154+
support::endian::write(BOS, S.AddressU, llvm::endianness::little);
155+
support::endian::write(BOS, S.AddressV, llvm::endianness::little);
156+
support::endian::write(BOS, S.AddressW, llvm::endianness::little);
157+
support::endian::write(BOS, S.MipLODBias, llvm::endianness::little);
158+
support::endian::write(BOS, S.MaxAnisotropy, llvm::endianness::little);
159+
support::endian::write(BOS, S.ComparisonFunc, llvm::endianness::little);
160+
support::endian::write(BOS, S.BorderColor, llvm::endianness::little);
161+
support::endian::write(BOS, S.MinLOD, llvm::endianness::little);
162+
support::endian::write(BOS, S.MaxLOD, llvm::endianness::little);
163+
support::endian::write(BOS, S.ShaderRegister, llvm::endianness::little);
164+
support::endian::write(BOS, S.RegisterSpace, llvm::endianness::little);
165+
support::endian::write(BOS, S.ShaderVisibility, llvm::endianness::little);
166+
}
159167
}
160168
assert(Storage.size() == getSize());
161169
OS.write(Storage.data(), Storage.size());

llvm/lib/Target/DirectX/DXILRootSignature.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313
#include "DXILRootSignature.h"
1414
#include "DirectX.h"
15+
#include "llvm/ADT/APFloat.h"
1516
#include "llvm/ADT/StringSwitch.h"
1617
#include "llvm/ADT/Twine.h"
1718
#include "llvm/Analysis/DXILMetadataAnalysis.h"
@@ -55,6 +56,13 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
5556
return std::nullopt;
5657
}
5758

59+
static std::optional<APFloat> extractMdFloatValue(MDNode *Node,
60+
unsigned int OpId) {
61+
if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
62+
return CI->getValue();
63+
return std::nullopt;
64+
}
65+
5866
static std::optional<StringRef> extractMdStringValue(MDNode *Node,
5967
unsigned int OpId) {
6068
MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
@@ -262,6 +270,81 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
262270
return false;
263271
}
264272

273+
static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
274+
MDNode *StaticSamplerNode) {
275+
if (StaticSamplerNode->getNumOperands() != 14)
276+
return reportError(Ctx, "Invalid format for Static Sampler");
277+
278+
dxbc::RTS0::v1::StaticSampler Sampler;
279+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
280+
Sampler.Filter = *Val;
281+
else
282+
return reportError(Ctx, "Invalid value for Filter");
283+
284+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
285+
Sampler.AddressU = *Val;
286+
else
287+
return reportError(Ctx, "Invalid value for AddressU");
288+
289+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
290+
Sampler.AddressV = *Val;
291+
else
292+
return reportError(Ctx, "Invalid value for AddressV");
293+
294+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
295+
Sampler.AddressW = *Val;
296+
else
297+
return reportError(Ctx, "Invalid value for AddressW");
298+
299+
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 5))
300+
Sampler.MipLODBias = Val->convertToFloat();
301+
else
302+
return reportError(Ctx, "Invalid value for MipLODBias");
303+
304+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
305+
Sampler.MaxAnisotropy = *Val;
306+
else
307+
return reportError(Ctx, "Invalid value for MaxAnisotropy");
308+
309+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
310+
Sampler.ComparisonFunc = *Val;
311+
else
312+
return reportError(Ctx, "Invalid value for ComparisonFunc ");
313+
314+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
315+
Sampler.BorderColor = *Val;
316+
else
317+
return reportError(Ctx, "Invalid value for ComparisonFunc ");
318+
319+
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 9))
320+
Sampler.MinLOD = Val->convertToFloat();
321+
else
322+
return reportError(Ctx, "Invalid value for MinLOD");
323+
324+
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 10))
325+
Sampler.MaxLOD = Val->convertToFloat();
326+
else
327+
return reportError(Ctx, "Invalid value for MaxLOD");
328+
329+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
330+
Sampler.ShaderRegister = *Val;
331+
else
332+
return reportError(Ctx, "Invalid value for ShaderRegister");
333+
334+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
335+
Sampler.RegisterSpace = *Val;
336+
else
337+
return reportError(Ctx, "Invalid value for RegisterSpace");
338+
339+
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
340+
Sampler.ShaderVisibility = *Val;
341+
else
342+
return reportError(Ctx, "Invalid value for ShaderVisibility");
343+
344+
RSD.StaticSamplers.push_back(Sampler);
345+
return false;
346+
}
347+
265348
static bool parseRootSignatureElement(LLVMContext *Ctx,
266349
mcdxbc::RootSignatureDesc &RSD,
267350
MDNode *Element) {
@@ -277,6 +360,7 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
277360
.Case("RootSRV", RootSignatureElementKind::SRV)
278361
.Case("RootUAV", RootSignatureElementKind::UAV)
279362
.Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
363+
.Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
280364
.Default(RootSignatureElementKind::Error);
281365

282366
switch (ElementKind) {
@@ -291,6 +375,8 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
291375
return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
292376
case RootSignatureElementKind::DescriptorTable:
293377
return parseDescriptorTable(Ctx, RSD, Element);
378+
case RootSignatureElementKind::StaticSamplers:
379+
return parseStaticSampler(Ctx, RSD, Element);
294380
case RootSignatureElementKind::Error:
295381
return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
296382
}
@@ -522,6 +608,9 @@ analyzeModule(Module &M) {
522608
// offset will always equal to the header size.
523609
RSD.RootParameterOffset = sizeof(dxbc::RTS0::v1::RootSignatureHeader);
524610

611+
// static sampler offset is calculated when writting dxcontainer.
612+
RSD.StaticSamplersOffset = 0u;
613+
525614
if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD)) {
526615
return RSDMap;
527616
}

llvm/lib/Target/DirectX/DXILRootSignature.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ enum class RootSignatureElementKind {
3232
UAV = 4,
3333
CBV = 5,
3434
DescriptorTable = 6,
35+
StaticSamplers = 7
3536
};
3637
class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
3738
friend AnalysisInfoMixin<RootSignatureAnalysis>;
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
; RUN: opt %s -dxil-embed -dxil-globals -S -o - | FileCheck %s
2+
; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC
3+
4+
target triple = "dxil-unknown-shadermodel6.0-compute"
5+
6+
; CHECK: @dx.rts0 = private constant [76 x i8] c"{{.*}}", section "RTS0", align 4
7+
8+
define void @main() #0 {
9+
entry:
10+
ret void
11+
}
12+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
13+
14+
15+
!dx.rootsignatures = !{!2} ; list of function/root signature pairs
16+
!2 = !{ ptr @main, !3 } ; function, root signature
17+
!3 = !{ !5 } ; list of root signature elements
18+
!5 = !{ !"StaticSampler", i32 4, i32 2, i32 3, i32 5, float 0x40403999A0000000, i32 9, i32 3, i32 2, float -1.280000e+02, float 1.280000e+02, i32 42, i32 0, i32 0 }
19+
20+
; DXC: - Name: RTS0
21+
; DXC-NEXT: Size: 76
22+
; DXC-NEXT: RootSignature:
23+
; DXC-NEXT: Version: 2
24+
; DXC-NEXT: NumRootParameters: 0
25+
; DXC-NEXT: RootParametersOffset: 24
26+
; DXC-NEXT: NumStaticSamplers: 1
27+
; DXC-NEXT: StaticSamplersOffset: 24
28+
; DXC-NEXT: Parameters: []
29+
; DXC-NEXT: Samplers:
30+
; DXC-NEXT: - Filter: 4
31+
; DXC-NEXT: AddressU: 2
32+
; DXC-NEXT: AddressV: 3
33+
; DXC-NEXT: AddressW: 5
34+
; DXC-NEXT: MipLODBias: 32.45
35+
; DXC-NEXT: MaxAnisotropy: 9
36+
; DXC-NEXT: ComparisonFunc: 3
37+
; DXC-NEXT: BorderColor: 2
38+
; DXC-NEXT: MinLOD: -128
39+
; DXC-NEXT: MaxLOD: 128
40+
; DXC-NEXT: ShaderRegister: 42
41+
; DXC-NEXT: RegisterSpace: 0
42+
; DXC-NEXT: ShaderVisibility: 0

0 commit comments

Comments
 (0)