Skip to content

Commit 6312d27

Browse files
authored
[DirectX] Emit hlsl.wavesize function attribute as entry property metadata (llvm#165624)
This PR adds support for emitting the `hlsl.wavesize` function attribute as entry property metadata for a compute shader. It follows the implementation of `hlsl.numthreads`.

- Collect the wave range information from the function attribute in `DXILMetadataAnalysis`
- Introduce the `WaveRange` property tag
- Emit `WaveSize` or `WaveRange` metadata (depending on shader model) in `DXILTranslateMetadata`
- Add tests for valid/invalid scenarios
- Update the base `PSVInfo` to reflect the min/max wave lane counts

Resolves llvm#70118
1 parent dd88923 commit 6312d27

File tree

6 files changed

+206
-13
lines changed

6 files changed

+206
-13
lines changed

llvm/include/llvm/Analysis/DXILMetadataAnalysis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ struct EntryProperties {
2727
unsigned NumThreadsX{0}; // X component
2828
unsigned NumThreadsY{0}; // Y component
2929
unsigned NumThreadsZ{0}; // Z component
30+
unsigned WaveSizeMin{0}; // Minimum component
31+
unsigned WaveSizeMax{0}; // Maximum component
32+
unsigned WaveSizePref{0}; // Preferred component
3033

3134
EntryProperties(const Function *Fn = nullptr) : Entry(Fn) {};
3235
};

llvm/lib/Analysis/DXILMetadataAnalysis.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,22 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
6666
Success = llvm::to_integer(NumThreadsVec[2], EFP.NumThreadsZ, 10);
6767
assert(Success && "Failed to parse Z component of numthreads");
6868
}
69+
// Get wavesize attribute value, if one exists
70+
StringRef WaveSizeStr =
71+
F.getFnAttribute("hlsl.wavesize").getValueAsString();
72+
if (!WaveSizeStr.empty()) {
73+
SmallVector<StringRef> WaveSizeVec;
74+
WaveSizeStr.split(WaveSizeVec, ',');
75+
assert(WaveSizeVec.size() == 3 && "Invalid wavesize specified");
76+
// Read in the three component values of wavesize
77+
[[maybe_unused]] bool Success =
78+
llvm::to_integer(WaveSizeVec[0], EFP.WaveSizeMin, 10);
79+
assert(Success && "Failed to parse Min component of wavesize");
80+
Success = llvm::to_integer(WaveSizeVec[1], EFP.WaveSizeMax, 10);
81+
assert(Success && "Failed to parse Max component of wavesize");
82+
Success = llvm::to_integer(WaveSizeVec[2], EFP.WaveSizePref, 10);
83+
assert(Success && "Failed to parse Preferred component of wavesize");
84+
}
6985
MMDAI.EntryPropertyVec.push_back(EFP);
7086
}
7187
return MMDAI;

llvm/lib/Target/DirectX/DXContainerGlobals.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,13 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
285285
PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
286286
PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
287287
PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ;
288+
if (MMI.EntryPropertyVec[0].WaveSizeMin) {
289+
PSV.BaseData.MinimumWaveLaneCount = MMI.EntryPropertyVec[0].WaveSizeMin;
290+
PSV.BaseData.MaximumWaveLaneCount =
291+
MMI.EntryPropertyVec[0].WaveSizeMax
292+
? MMI.EntryPropertyVec[0].WaveSizeMax
293+
: MMI.EntryPropertyVec[0].WaveSizeMin;
294+
}
288295
break;
289296
default:
290297
break;

llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ enum class EntryPropsTag {
8282
ASStateTag,
8383
WaveSize,
8484
EntryRootSig,
85+
WaveRange = 23,
8586
};
8687

8788
} // namespace
@@ -177,14 +178,15 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
177178
case EntryPropsTag::ASStateTag:
178179
case EntryPropsTag::WaveSize:
179180
case EntryPropsTag::EntryRootSig:
181+
case EntryPropsTag::WaveRange:
180182
llvm_unreachable("NYI: Unhandled entry property tag");
181183
}
182184
return MDVals;
183185
}
184186

185-
static MDTuple *
186-
getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
187-
const Triple::EnvironmentType ShaderProfile) {
187+
static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP,
188+
uint64_t EntryShaderFlags,
189+
const ModuleMetadataInfo &MMDI) {
188190
SmallVector<Metadata *> MDVals;
189191
LLVMContext &Ctx = EP.Entry->getContext();
190192
if (EntryShaderFlags != 0)
@@ -195,12 +197,13 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
195197
// FIXME: support more props.
196198
// See https://github.com/llvm/llvm-project/issues/57948.
197199
// Add shader kind for lib entries.
198-
if (ShaderProfile == Triple::EnvironmentType::Library &&
200+
if (MMDI.ShaderProfile == Triple::EnvironmentType::Library &&
199201
EP.ShaderStage != Triple::EnvironmentType::Library)
200202
MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
201203
getShaderStage(EP.ShaderStage), Ctx));
202204

203205
if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
206+
// Handle mandatory "hlsl.numthreads"
204207
MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
205208
Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
206209
Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
@@ -210,8 +213,48 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
210213
ConstantAsMetadata::get(ConstantInt::get(
211214
Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
212215
MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
216+
217+
// Handle optional "hlsl.wavesize". The fields are optionally represented
218+
// if they are non-zero.
219+
if (EP.WaveSizeMin != 0) {
220+
bool IsWaveRange = VersionTuple(6, 8) <= MMDI.ShaderModelVersion;
221+
bool IsWaveSize =
222+
!IsWaveRange && VersionTuple(6, 6) <= MMDI.ShaderModelVersion;
223+
224+
if (!IsWaveRange && !IsWaveSize) {
225+
reportError(M, "Shader model 6.6 or greater is required to specify "
226+
"the \"hlsl.wavesize\" function attribute");
227+
return nullptr;
228+
}
229+
230+
// A range is being specified if EP.WaveSizeMax != 0
231+
if (EP.WaveSizeMax && !IsWaveRange) {
232+
reportError(
233+
M, "Shader model 6.8 or greater is required to specify "
234+
"wave size range values of the \"hlsl.wavesize\" function "
235+
"attribute");
236+
return nullptr;
237+
}
238+
239+
EntryPropsTag Tag =
240+
IsWaveSize ? EntryPropsTag::WaveSize : EntryPropsTag::WaveRange;
241+
MDVals.emplace_back(ConstantAsMetadata::get(
242+
ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));
243+
244+
SmallVector<Metadata *> WaveSizeVals = {ConstantAsMetadata::get(
245+
ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMin))};
246+
if (IsWaveRange) {
247+
WaveSizeVals.push_back(ConstantAsMetadata::get(
248+
ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMax)));
249+
WaveSizeVals.push_back(ConstantAsMetadata::get(
250+
ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizePref)));
251+
}
252+
253+
MDVals.emplace_back(MDNode::get(Ctx, WaveSizeVals));
254+
}
213255
}
214256
}
257+
215258
if (MDVals.empty())
216259
return nullptr;
217260
return MDNode::get(Ctx, MDVals);
@@ -236,12 +279,11 @@ static MDTuple *constructEntryMetadata(const Function *EntryFn,
236279
return MDNode::get(Ctx, MDVals);
237280
}
238281

239-
static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
240-
MDNode *MDResources,
282+
static MDTuple *emitEntryMD(Module &M, const EntryProperties &EP,
283+
MDTuple *Signatures, MDNode *MDResources,
241284
const uint64_t EntryShaderFlags,
242-
const Triple::EnvironmentType ShaderProfile) {
243-
MDTuple *Properties =
244-
getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
285+
const ModuleMetadataInfo &MMDI) {
286+
MDTuple *Properties = getEntryPropAsMetadata(M, EP, EntryShaderFlags, MMDI);
245287
return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
246288
EP.Entry->getContext());
247289
}
@@ -523,10 +565,8 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
523565
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
524566
"'"));
525567
}
526-
527-
EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
528-
EntryShaderFlags,
529-
MMDI.ShaderProfile));
568+
EntryFnMDNodes.emplace_back(emitEntryMD(
569+
M, EntryProp, Signatures, ResourceMD, EntryShaderFlags, MMDI));
530570
}
531571

532572
NamedMDNode *EntryPointsNamedMD =
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
; RUN: split-file %s %t
2+
; RUN: not opt -S --dxil-translate-metadata %t/low-sm.ll 2>&1 | FileCheck %t/low-sm.ll
3+
; RUN: not opt -S --dxil-translate-metadata %t/low-sm-for-range.ll 2>&1 | FileCheck %t/low-sm-for-range.ll
4+
5+
; Test that wavesize metadata is only allowed on applicable shader model versions
6+
7+
;--- low-sm.ll
8+
9+
; CHECK: Shader model 6.6 or greater is required to specify the "hlsl.wavesize" function attribute
10+
11+
target triple = "dxil-unknown-shadermodel6.5-compute"
12+
13+
define void @main() #0 {
14+
entry:
15+
ret void
16+
}
17+
18+
attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
19+
20+
;--- low-sm-for-range.ll
21+
22+
; CHECK: Shader model 6.8 or greater is required to specify wave size range values of the "hlsl.wavesize" function attribute
23+
24+
target triple = "dxil-unknown-shadermodel6.7-compute"
25+
26+
define void @main() #0 {
27+
entry:
28+
ret void
29+
}
30+
31+
attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
; RUN: split-file %s %t
2+
; RUN: opt -S --dxil-translate-metadata %t/only.ll | FileCheck %t/only.ll
3+
; RUN: opt -S --dxil-translate-metadata %t/min.ll | FileCheck %t/min.ll
4+
; RUN: opt -S --dxil-translate-metadata %t/max.ll | FileCheck %t/max.ll
5+
; RUN: opt -S --dxil-translate-metadata %t/pref.ll | FileCheck %t/pref.ll
6+
7+
; RUN: llc --filetype=obj %t/only.ll -o - | obj2yaml | FileCheck %t/only.ll --check-prefix=OBJ
8+
; RUN: llc --filetype=obj %t/min.ll -o - | obj2yaml | FileCheck %t/min.ll --check-prefix=OBJ
9+
; RUN: llc --filetype=obj %t/max.ll -o - | obj2yaml | FileCheck %t/max.ll --check-prefix=OBJ
10+
; RUN: llc --filetype=obj %t/pref.ll -o - | obj2yaml | FileCheck %t/pref.ll --check-prefix=OBJ
11+
12+
; Test that wave size/range metadata is correctly generated with the correct tag
13+
14+
;--- only.ll
15+
16+
; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
17+
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
18+
; CHECK: ![[#PROPS]] = !{{{.*}}i32 11, ![[#WAVE_SIZE:]]{{.*}}}
19+
; CHECK: ![[#WAVE_SIZE]] = !{i32 16}
20+
21+
; OBJ: - Name: PSV0
22+
; OBJ: PSVInfo:
23+
; OBJ: MinimumWaveLaneCount: 16
24+
; OBJ: MaximumWaveLaneCount: 16
25+
26+
target triple = "dxil-unknown-shadermodel6.6-compute"
27+
28+
define void @main() #0 {
29+
entry:
30+
ret void
31+
}
32+
33+
attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
34+
35+
;--- min.ll
36+
37+
; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
38+
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
39+
; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
40+
; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 0, i32 0}
41+
42+
; OBJ: - Name: PSV0
43+
; OBJ: PSVInfo:
44+
; OBJ: MinimumWaveLaneCount: 16
45+
; OBJ: MaximumWaveLaneCount: 16
46+
47+
target triple = "dxil-unknown-shadermodel6.8-compute"
48+
49+
define void @main() #0 {
50+
entry:
51+
ret void
52+
}
53+
54+
attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
55+
56+
;--- max.ll
57+
58+
; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
59+
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
60+
; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
61+
; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 32, i32 0}
62+
63+
; OBJ: - Name: PSV0
64+
; OBJ: PSVInfo:
65+
; OBJ: MinimumWaveLaneCount: 16
66+
; OBJ: MaximumWaveLaneCount: 32
67+
68+
target triple = "dxil-unknown-shadermodel6.8-compute"
69+
70+
define void @main() #0 {
71+
entry:
72+
ret void
73+
}
74+
75+
attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
76+
77+
;--- pref.ll
78+
79+
; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
80+
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
81+
; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
82+
; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 64, i32 32}
83+
84+
; OBJ: - Name: PSV0
85+
; OBJ: PSVInfo:
86+
; OBJ: MinimumWaveLaneCount: 16
87+
; OBJ: MaximumWaveLaneCount: 64
88+
89+
target triple = "dxil-unknown-shadermodel6.8-compute"
90+
91+
define void @main() #0 {
92+
entry:
93+
ret void
94+
}
95+
96+
attributes #0 = { "hlsl.wavesize"="16,64,32" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

0 commit comments

Comments
 (0)