Skip to content

Commit 793019a

Browse files
committed
[DirectX] Emit WaveSize function attribute metadata
1 parent ad29838 commit 793019a

File tree

5 files changed

+174
-14
lines changed

5 files changed

+174
-14
lines changed

llvm/include/llvm/Analysis/DXILMetadataAnalysis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ struct EntryProperties {
2727
unsigned NumThreadsX{0}; // X component
2828
unsigned NumThreadsY{0}; // Y component
2929
unsigned NumThreadsZ{0}; // Z component
30+
unsigned WaveSizeMin{0}; // Minimum component
31+
unsigned WaveSizeMax{0}; // Maximum component
32+
unsigned WaveSizePref{0}; // Preferred component
3033

3134
EntryProperties(const Function *Fn = nullptr) : Entry(Fn) {};
3235
};

llvm/lib/Analysis/DXILMetadataAnalysis.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,22 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
6666
Success = llvm::to_integer(NumThreadsVec[2], EFP.NumThreadsZ, 10);
6767
assert(Success && "Failed to parse Z component of numthreads");
6868
}
69+
// Get wavesize attribute value, if one exists
70+
StringRef WaveSizeStr =
71+
F.getFnAttribute("hlsl.wavesize").getValueAsString();
72+
if (!WaveSizeStr.empty()) {
73+
SmallVector<StringRef> WaveSizeVec;
74+
WaveSizeStr.split(WaveSizeVec, ',');
75+
assert(WaveSizeVec.size() == 3 && "Invalid wavesize specified");
76+
// Read in the three component values of numthreads
77+
[[maybe_unused]] bool Success =
78+
llvm::to_integer(WaveSizeVec[0], EFP.WaveSizeMin, 10);
79+
assert(Success && "Failed to parse Min component of wavesize");
80+
Success = llvm::to_integer(WaveSizeVec[1], EFP.WaveSizeMax, 10);
81+
assert(Success && "Failed to parse Max component of wavesize");
82+
Success = llvm::to_integer(WaveSizeVec[2], EFP.WaveSizePref, 10);
83+
assert(Success && "Failed to parse Preferred component of wavesize");
84+
}
6985
MMDAI.EntryPropertyVec.push_back(EFP);
7086
}
7187
return MMDAI;

llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ enum class EntryPropsTag {
8282
ASStateTag,
8383
WaveSize,
8484
EntryRootSig,
85+
WaveRange = 23,
8586
};
8687

8788
} // namespace
@@ -177,30 +178,32 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
177178
case EntryPropsTag::ASStateTag:
178179
case EntryPropsTag::WaveSize:
179180
case EntryPropsTag::EntryRootSig:
181+
case EntryPropsTag::WaveRange:
180182
llvm_unreachable("NYI: Unhandled entry property tag");
181183
}
182184
return MDVals;
183185
}
184186

185-
static MDTuple *
186-
getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
187-
const Triple::EnvironmentType ShaderProfile) {
187+
static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP,
188+
uint64_t EntryShaderFlags,
189+
const ModuleMetadataInfo &MMDI) {
188190
SmallVector<Metadata *> MDVals;
189191
LLVMContext &Ctx = EP.Entry->getContext();
190192
if (EntryShaderFlags != 0)
191193
MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags,
192-
EntryShaderFlags, Ctx));
194+
MMDI.ShaderProfile, Ctx));
193195

194196
if (EP.Entry != nullptr) {
195197
// FIXME: support more props.
196198
// See https://github.com/llvm/llvm-project/issues/57948.
197199
// Add shader kind for lib entries.
198-
if (ShaderProfile == Triple::EnvironmentType::Library &&
200+
if (MMDI.ShaderProfile == Triple::EnvironmentType::Library &&
199201
EP.ShaderStage != Triple::EnvironmentType::Library)
200202
MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
201203
getShaderStage(EP.ShaderStage), Ctx));
202204

203205
if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
206+
// Handle mandatory "hlsl.numthreads"
204207
MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
205208
Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
206209
Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
@@ -210,8 +213,47 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
210213
ConstantAsMetadata::get(ConstantInt::get(
211214
Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
212215
MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
216+
217+
// Handle optional "hlsl.wavesize". The fields are optionally represented
218+
// if they are non-zero.
219+
if (EP.WaveSizeMin != 0) {
220+
bool IsWaveRange = VersionTuple(6, 8) <= MMDI.ShaderModelVersion;
221+
bool IsWaveSize =
222+
!IsWaveRange && VersionTuple(6, 6) <= MMDI.ShaderModelVersion;
223+
224+
if (!IsWaveRange && !IsWaveSize) {
225+
reportError(M, "Shader model 6.6 or greater is required to specify "
226+
"the \"hlsl.wavesize\" function attribute");
227+
return nullptr;
228+
}
229+
230+
if (EP.WaveSizeMax && !IsWaveRange) {
231+
reportError(
232+
M, "Shader model 6.8 or greater is required to specify "
233+
"wave size range values of the \"hlsl.wavesize\" function "
234+
"attribute");
235+
return nullptr;
236+
}
237+
238+
EntryPropsTag Tag =
239+
IsWaveSize ? EntryPropsTag::WaveSize : EntryPropsTag::WaveRange;
240+
MDVals.emplace_back(ConstantAsMetadata::get(
241+
ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));
242+
243+
SmallVector<Metadata *> WaveSizeVals = {ConstantAsMetadata::get(
244+
ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMin))};
245+
if (IsWaveRange) {
246+
WaveSizeVals.push_back(ConstantAsMetadata::get(
247+
ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMax)));
248+
WaveSizeVals.push_back(ConstantAsMetadata::get(
249+
ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizePref)));
250+
}
251+
252+
MDVals.emplace_back(MDNode::get(Ctx, WaveSizeVals));
253+
}
213254
}
214255
}
256+
215257
if (MDVals.empty())
216258
return nullptr;
217259
return MDNode::get(Ctx, MDVals);
@@ -236,12 +278,11 @@ static MDTuple *constructEntryMetadata(const Function *EntryFn,
236278
return MDNode::get(Ctx, MDVals);
237279
}
238280

239-
static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
240-
MDNode *MDResources,
281+
static MDTuple *emitEntryMD(Module &M, const EntryProperties &EP,
282+
MDTuple *Signatures, MDNode *MDResources,
241283
const uint64_t EntryShaderFlags,
242-
const Triple::EnvironmentType ShaderProfile) {
243-
MDTuple *Properties =
244-
getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
284+
const ModuleMetadataInfo &MMDI) {
285+
MDTuple *Properties = getEntryPropAsMetadata(M, EP, EntryShaderFlags, MMDI);
245286
return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
246287
EP.Entry->getContext());
247288
}
@@ -523,10 +564,8 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
523564
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
524565
"'"));
525566
}
526-
527-
EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
528-
EntryShaderFlags,
529-
MMDI.ShaderProfile));
567+
EntryFnMDNodes.emplace_back(emitEntryMD(
568+
M, EntryProp, Signatures, ResourceMD, EntryShaderFlags, MMDI));
530569
}
531570

532571
NamedMDNode *EntryPointsNamedMD =
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
; RUN: split-file %s %t
2+
; RUN: not opt -S --dxil-translate-metadata %t/low-sm.ll 2>&1 | FileCheck %t/low-sm.ll
3+
; RUN: not opt -S --dxil-translate-metadata %t/low-sm-for-range.ll 2>&1 | FileCheck %t/low-sm-for-range.ll
4+
5+
; Test that wavesize metadata is only allowed on applicable shader model versions
6+
7+
;--- low-sm.ll
8+
9+
; CHECK: Shader model 6.6 or greater is required to specify the "hlsl.wavesize" function attribute
10+
11+
target triple = "dxil-unknown-shadermodel6.5-compute"
12+
13+
define void @main() #0 {
14+
entry:
15+
ret void
16+
}
17+
18+
attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
19+
20+
;--- low-sm-for-range.ll
21+
22+
; CHECK: Shader model 6.8 or greater is required to specify wave size range values of the "hlsl.wavesize" function attribute
23+
24+
target triple = "dxil-unknown-shadermodel6.7-compute"
25+
26+
define void @main() #0 {
27+
entry:
28+
ret void
29+
}
30+
31+
attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
; RUN: split-file %s %t
2+
; RUN: opt -S --dxil-translate-metadata %t/only.ll | FileCheck %t/only.ll
3+
; RUN: opt -S --dxil-translate-metadata %t/min.ll | FileCheck %t/min.ll
4+
; RUN: opt -S --dxil-translate-metadata %t/max.ll | FileCheck %t/max.ll
5+
; RUN: opt -S --dxil-translate-metadata %t/pref.ll | FileCheck %t/pref.ll
6+
7+
; Test that wave size/range metadata is correctly generated with the correct tag
8+
9+
;--- only.ll
10+
11+
; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
12+
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
13+
; CHECK: ![[#PROPS]] = !{{{.*}}i32 11, ![[#WAVE_SIZE:]]{{.*}}}
14+
; CHECK: ![[#WAVE_SIZE]] = !{i32 16}
15+
16+
target triple = "dxil-unknown-shadermodel6.6-compute"
17+
18+
define void @main() #0 {
19+
entry:
20+
ret void
21+
}
22+
23+
attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
24+
25+
;--- min.ll
26+
27+
; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
28+
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
29+
; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
30+
; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 0, i32 0}
31+
32+
target triple = "dxil-unknown-shadermodel6.8-compute"
33+
34+
define void @main() #0 {
35+
entry:
36+
ret void
37+
}
38+
39+
attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
40+
41+
;--- max.ll
42+
43+
; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
44+
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
45+
; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
46+
; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 32, i32 0}
47+
48+
target triple = "dxil-unknown-shadermodel6.8-compute"
49+
50+
define void @main() #0 {
51+
entry:
52+
ret void
53+
}
54+
55+
attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
56+
57+
;--- pref.ll
58+
59+
; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
60+
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
61+
; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
62+
; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 64, i32 32}
63+
64+
target triple = "dxil-unknown-shadermodel6.8-compute"
65+
66+
define void @main() #0 {
67+
entry:
68+
ret void
69+
}
70+
71+
attributes #0 = { "hlsl.wavesize"="16,64,32" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

0 commit comments

Comments
 (0)