-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[DirectX] Emit hlsl.wavesize function attribute as entry property metadata
#165624
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
WaveSize function attribute metadatahlsl.wavesize function attribute as entry property metadata
|
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-backend-directx Author: Finn Plummer (inbelic) ChangesThis pr adds support for emitting the It follows the implementation of
Resolves #70118 Full diff: https://github.com/llvm/llvm-project/pull/165624.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
index cb535ac14f1c6..a1b030c157eae 100644
--- a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
+++ b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
@@ -27,6 +27,9 @@ struct EntryProperties {
unsigned NumThreadsX{0}; // X component
unsigned NumThreadsY{0}; // Y component
unsigned NumThreadsZ{0}; // Z component
+ unsigned WaveSizeMin{0}; // Minimum component
+ unsigned WaveSizeMax{0}; // Maximum component
+ unsigned WaveSizePref{0}; // Preferred component
EntryProperties(const Function *Fn = nullptr) : Entry(Fn) {};
};
diff --git a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
index 23f1aa82ae8a3..bd77cba385667 100644
--- a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
+++ b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
@@ -66,6 +66,22 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
Success = llvm::to_integer(NumThreadsVec[2], EFP.NumThreadsZ, 10);
assert(Success && "Failed to parse Z component of numthreads");
}
+ // Get wavesize attribute value, if one exists
+ StringRef WaveSizeStr =
+ F.getFnAttribute("hlsl.wavesize").getValueAsString();
+ if (!WaveSizeStr.empty()) {
+ SmallVector<StringRef> WaveSizeVec;
+ WaveSizeStr.split(WaveSizeVec, ',');
+ assert(WaveSizeVec.size() == 3 && "Invalid wavesize specified");
+ // Read in the three component values of numthreads
+ [[maybe_unused]] bool Success =
+ llvm::to_integer(WaveSizeVec[0], EFP.WaveSizeMin, 10);
+ assert(Success && "Failed to parse Min component of wavesize");
+ Success = llvm::to_integer(WaveSizeVec[1], EFP.WaveSizeMax, 10);
+ assert(Success && "Failed to parse Max component of wavesize");
+ Success = llvm::to_integer(WaveSizeVec[2], EFP.WaveSizePref, 10);
+ assert(Success && "Failed to parse Preferred component of wavesize");
+ }
MMDAI.EntryPropertyVec.push_back(EFP);
}
return MMDAI;
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index cf8b833b3e42e..682847a94c6fb 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -82,6 +82,7 @@ enum class EntryPropsTag {
ASStateTag,
WaveSize,
EntryRootSig,
+ WaveRange = 23,
};
} // namespace
@@ -177,30 +178,32 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
case EntryPropsTag::ASStateTag:
case EntryPropsTag::WaveSize:
case EntryPropsTag::EntryRootSig:
+ case EntryPropsTag::WaveRange:
llvm_unreachable("NYI: Unhandled entry property tag");
}
return MDVals;
}
-static MDTuple *
-getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
- const Triple::EnvironmentType ShaderProfile) {
+static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP,
+ uint64_t EntryShaderFlags,
+ const ModuleMetadataInfo &MMDI) {
SmallVector<Metadata *> MDVals;
LLVMContext &Ctx = EP.Entry->getContext();
if (EntryShaderFlags != 0)
MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags,
- EntryShaderFlags, Ctx));
+ MMDI.ShaderProfile, Ctx));
if (EP.Entry != nullptr) {
// FIXME: support more props.
// See https://github.com/llvm/llvm-project/issues/57948.
// Add shader kind for lib entries.
- if (ShaderProfile == Triple::EnvironmentType::Library &&
+ if (MMDI.ShaderProfile == Triple::EnvironmentType::Library &&
EP.ShaderStage != Triple::EnvironmentType::Library)
MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
getShaderStage(EP.ShaderStage), Ctx));
if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
+ // Handle mandatory "hlsl.numthreads"
MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
@@ -210,8 +213,47 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
+
+ // Handle optional "hlsl.wavesize". The fields are optionally represented
+ // if they are non-zero.
+ if (EP.WaveSizeMin != 0) {
+ bool IsWaveRange = VersionTuple(6, 8) <= MMDI.ShaderModelVersion;
+ bool IsWaveSize =
+ !IsWaveRange && VersionTuple(6, 6) <= MMDI.ShaderModelVersion;
+
+ if (!IsWaveRange && !IsWaveSize) {
+ reportError(M, "Shader model 6.6 or greater is required to specify "
+ "the \"hlsl.wavesize\" function attribute");
+ return nullptr;
+ }
+
+ if (EP.WaveSizeMax && !IsWaveRange) {
+ reportError(
+ M, "Shader model 6.8 or greater is required to specify "
+ "wave size range values of the \"hlsl.wavesize\" function "
+ "attribute");
+ return nullptr;
+ }
+
+ EntryPropsTag Tag =
+ IsWaveSize ? EntryPropsTag::WaveSize : EntryPropsTag::WaveRange;
+ MDVals.emplace_back(ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));
+
+ SmallVector<Metadata *> WaveSizeVals = {ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMin))};
+ if (IsWaveRange) {
+ WaveSizeVals.push_back(ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMax)));
+ WaveSizeVals.push_back(ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizePref)));
+ }
+
+ MDVals.emplace_back(MDNode::get(Ctx, WaveSizeVals));
+ }
}
}
+
if (MDVals.empty())
return nullptr;
return MDNode::get(Ctx, MDVals);
@@ -236,12 +278,11 @@ static MDTuple *constructEntryMetadata(const Function *EntryFn,
return MDNode::get(Ctx, MDVals);
}
-static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
- MDNode *MDResources,
+static MDTuple *emitEntryMD(Module &M, const EntryProperties &EP,
+ MDTuple *Signatures, MDNode *MDResources,
const uint64_t EntryShaderFlags,
- const Triple::EnvironmentType ShaderProfile) {
- MDTuple *Properties =
- getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
+ const ModuleMetadataInfo &MMDI) {
+ MDTuple *Properties = getEntryPropAsMetadata(M, EP, EntryShaderFlags, MMDI);
return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
EP.Entry->getContext());
}
@@ -523,10 +564,8 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
"'"));
}
-
- EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
- EntryShaderFlags,
- MMDI.ShaderProfile));
+ EntryFnMDNodes.emplace_back(emitEntryMD(
+ M, EntryProp, Signatures, ResourceMD, EntryShaderFlags, MMDI));
}
NamedMDNode *EntryPointsNamedMD =
diff --git a/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll b/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
new file mode 100644
index 0000000000000..9016c5d7e8d44
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
@@ -0,0 +1,31 @@
+; RUN: split-file %s %t
+; RUN: not opt -S --dxil-translate-metadata %t/low-sm.ll 2>&1 | FileCheck %t/low-sm.ll
+; RUN: not opt -S --dxil-translate-metadata %t/low-sm-for-range.ll 2>&1 | FileCheck %t/low-sm-for-range.ll
+
+; Test that wavesize metadata is only allowed on applicable shader model versions
+
+;--- low-sm.ll
+
+; CHECK: Shader model 6.6 or greater is required to specify the "hlsl.wavesize" function attribute
+
+target triple = "dxil-unknown-shadermodel6.5-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- low-sm-for-range.ll
+
+; CHECK: Shader model 6.8 or greater is required to specify wave size range values of the "hlsl.wavesize" function attribute
+
+target triple = "dxil-unknown-shadermodel6.7-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
diff --git a/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll b/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll
new file mode 100644
index 0000000000000..63e8a59eb2648
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll
@@ -0,0 +1,71 @@
+; RUN: split-file %s %t
+; RUN: opt -S --dxil-translate-metadata %t/only.ll | FileCheck %t/only.ll
+; RUN: opt -S --dxil-translate-metadata %t/min.ll | FileCheck %t/min.ll
+; RUN: opt -S --dxil-translate-metadata %t/max.ll | FileCheck %t/max.ll
+; RUN: opt -S --dxil-translate-metadata %t/pref.ll | FileCheck %t/pref.ll
+
+; Test that wave size/range metadata is correctly generated with the correct tag
+
+;--- only.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 11, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16}
+
+target triple = "dxil-unknown-shadermodel6.6-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- min.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 0, i32 0}
+
+target triple = "dxil-unknown-shadermodel6.8-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- max.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 32, i32 0}
+
+target triple = "dxil-unknown-shadermodel6.8-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- pref.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 64, i32 32}
+
+target triple = "dxil-unknown-shadermodel6.8-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,64,32" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
|
Hold on merging. I'm getting several DXIL validation failures with this PR |
Validation failures encountered with this PR.
This comment was marked as outdated.
This comment was marked as outdated.
|
This test case fails DXIL validation with this PR: [numthreads(16, 1, 1)][WaveSize(16)] void CSMain() {} |
|
Looks like there is some custom handling that I missed from here: https://github.com/microsoft/DirectXShaderCompiler/blob/fa3fe957ea9656df75ce683c47fb2892151e4732/lib/DxilContainer/DxilContainerAssembler.cpp#L1536. Looking to update it now. Thanks for catching that. |
This pr adds support for emitting the
hlsl.wavesizefunction attribute as an entry property metadata for a compute shader.It follows the implementation of
hlsl.numthreads.DXILMetadataAnalysisWaveRangeproperty tagWaveSizeorWaveRangemetadata (depending on shader model) inDXILTranslateMetadataPSVInfoto reflect the min/max wave lane countsResolves #70118