From 793019a5052f27d770f9874f3cea311c9968339d Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Wed, 29 Oct 2025 13:45:14 -0700 Subject: [PATCH 1/3] [DirectX] Emit `WaveSize` function attribute metadata --- .../llvm/Analysis/DXILMetadataAnalysis.h | 3 + llvm/lib/Analysis/DXILMetadataAnalysis.cpp | 16 +++++ .../Target/DirectX/DXILTranslateMetadata.cpp | 67 +++++++++++++---- llvm/test/CodeGen/DirectX/wavesize-md-errs.ll | 31 ++++++++ .../test/CodeGen/DirectX/wavesize-md-valid.ll | 71 +++++++++++++++++++ 5 files changed, 174 insertions(+), 14 deletions(-) create mode 100644 llvm/test/CodeGen/DirectX/wavesize-md-errs.ll create mode 100644 llvm/test/CodeGen/DirectX/wavesize-md-valid.ll diff --git a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h index cb535ac14f1c6..a1b030c157eae 100644 --- a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h +++ b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h @@ -27,6 +27,9 @@ struct EntryProperties { unsigned NumThreadsX{0}; // X component unsigned NumThreadsY{0}; // Y component unsigned NumThreadsZ{0}; // Z component + unsigned WaveSizeMin{0}; // Minimum component + unsigned WaveSizeMax{0}; // Maximum component + unsigned WaveSizePref{0}; // Preferred component EntryProperties(const Function *Fn = nullptr) : Entry(Fn) {}; }; diff --git a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp index 23f1aa82ae8a3..bd77cba385667 100644 --- a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp +++ b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp @@ -66,6 +66,22 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) { Success = llvm::to_integer(NumThreadsVec[2], EFP.NumThreadsZ, 10); assert(Success && "Failed to parse Z component of numthreads"); } + // Get wavesize attribute value, if one exists + StringRef WaveSizeStr = + F.getFnAttribute("hlsl.wavesize").getValueAsString(); + if (!WaveSizeStr.empty()) { + SmallVector WaveSizeVec; + WaveSizeStr.split(WaveSizeVec, ','); + assert(WaveSizeVec.size() == 3 && "Invalid wavesize specified"); + // Read in the three component values of numthreads + [[maybe_unused]] bool Success = + llvm::to_integer(WaveSizeVec[0], EFP.WaveSizeMin, 10); + assert(Success && "Failed to parse Min component of wavesize"); + Success = llvm::to_integer(WaveSizeVec[1], EFP.WaveSizeMax, 10); + assert(Success && "Failed to parse Max component of wavesize"); + Success = llvm::to_integer(WaveSizeVec[2], EFP.WaveSizePref, 10); + assert(Success && "Failed to parse Preferred component of wavesize"); + } MMDAI.EntryPropertyVec.push_back(EFP); } return MMDAI; diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp index cf8b833b3e42e..682847a94c6fb 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -82,6 +82,7 @@ enum class EntryPropsTag { ASStateTag, WaveSize, EntryRootSig, + WaveRange = 23, }; } // namespace @@ -177,30 +178,32 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) { case EntryPropsTag::ASStateTag: case EntryPropsTag::WaveSize: case EntryPropsTag::EntryRootSig: + case EntryPropsTag::WaveRange: llvm_unreachable("NYI: Unhandled entry property tag"); } return MDVals; } -static MDTuple * -getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags, - const Triple::EnvironmentType ShaderProfile) { +static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP, + uint64_t EntryShaderFlags, + const ModuleMetadataInfo &MMDI) { SmallVector MDVals; LLVMContext &Ctx = EP.Entry->getContext(); if (EntryShaderFlags != 0) MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags, - EntryShaderFlags, Ctx)); + MMDI.ShaderProfile, Ctx)); if (EP.Entry != nullptr) { // FIXME: support more props. // See https://github.com/llvm/llvm-project/issues/57948. // Add shader kind for lib entries. - if (ShaderProfile == Triple::EnvironmentType::Library && + if (MMDI.ShaderProfile == Triple::EnvironmentType::Library && EP.ShaderStage != Triple::EnvironmentType::Library) MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind, getShaderStage(EP.ShaderStage), Ctx)); if (EP.ShaderStage == Triple::EnvironmentType::Compute) { + // Handle mandatory "hlsl.numthreads" MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get( Type::getInt32Ty(Ctx), static_cast(EntryPropsTag::NumThreads)))); Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get( @@ -210,8 +213,47 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags, ConstantAsMetadata::get(ConstantInt::get( Type::getInt32Ty(Ctx), EP.NumThreadsZ))}; MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals)); + + // Handle optional "hlsl.wavesize". The fields are optionally represented + // if they are non-zero. + if (EP.WaveSizeMin != 0) { + bool IsWaveRange = VersionTuple(6, 8) <= MMDI.ShaderModelVersion; + bool IsWaveSize = + !IsWaveRange && VersionTuple(6, 6) <= MMDI.ShaderModelVersion; + + if (!IsWaveRange && !IsWaveSize) { + reportError(M, "Shader model 6.6 or greater is required to specify " + "the \"hlsl.wavesize\" function attribute"); + return nullptr; + } + + if (EP.WaveSizeMax && !IsWaveRange) { + reportError( + M, "Shader model 6.8 or greater is required to specify " + "wave size range values of the \"hlsl.wavesize\" function " + "attribute"); + return nullptr; + } + + EntryPropsTag Tag = + IsWaveSize ? EntryPropsTag::WaveSize : EntryPropsTag::WaveRange; + MDVals.emplace_back(ConstantAsMetadata::get( + ConstantInt::get(Type::getInt32Ty(Ctx), static_cast(Tag)))); + + SmallVector WaveSizeVals = {ConstantAsMetadata::get( + ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMin))}; + if (IsWaveRange) { + WaveSizeVals.push_back(ConstantAsMetadata::get( + ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMax))); + WaveSizeVals.push_back(ConstantAsMetadata::get( + ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizePref))); + } + + MDVals.emplace_back(MDNode::get(Ctx, WaveSizeVals)); + } } } + if (MDVals.empty()) return nullptr; return MDNode::get(Ctx, MDVals); @@ -236,12 +278,11 @@ static MDTuple *constructEntryMetadata(const Function *EntryFn, return MDNode::get(Ctx, MDVals); } -static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures, - MDNode *MDResources, +static MDTuple *emitEntryMD(Module &M, const EntryProperties &EP, + MDTuple *Signatures, MDNode *MDResources, const uint64_t EntryShaderFlags, - const Triple::EnvironmentType ShaderProfile) { - MDTuple *Properties = - getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile); + const ModuleMetadataInfo &MMDI) { + MDTuple *Properties = getEntryPropAsMetadata(M, EP, EntryShaderFlags, MMDI); return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties, EP.Entry->getContext()); } @@ -523,10 +564,8 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM, Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) + "'")); } - - EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD, - EntryShaderFlags, - MMDI.ShaderProfile)); + EntryFnMDNodes.emplace_back(emitEntryMD( + M, EntryProp, Signatures, ResourceMD, EntryShaderFlags, MMDI)); } NamedMDNode *EntryPointsNamedMD = diff --git a/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll b/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll new file mode 100644 index 0000000000000..9016c5d7e8d44 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll @@ -0,0 +1,31 @@ +; RUN: split-file %s %t +; RUN: not opt -S --dxil-translate-metadata %t/low-sm.ll 2>&1 | FileCheck %t/low-sm.ll +; RUN: not opt -S --dxil-translate-metadata %t/low-sm-for-range.ll 2>&1 | FileCheck %t/low-sm-for-range.ll + +; Test that wavesize metadata is only allowed on applicable shader model versions + +;--- low-sm.ll + +; CHECK: Shader model 6.6 or greater is required to specify the "hlsl.wavesize" function attribute + +target triple = "dxil-unknown-shadermodel6.5-compute" + +define void @main() #0 { +entry: + ret void +} + +attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } + +;--- low-sm-for-range.ll + +; CHECK: Shader model 6.8 or greater is required to specify wave size range values of the "hlsl.wavesize" function attribute + +target triple = "dxil-unknown-shadermodel6.7-compute" + +define void @main() #0 { +entry: + ret void +} + +attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } diff --git a/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll b/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll new file mode 100644 index 0000000000000..63e8a59eb2648 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll @@ -0,0 +1,71 @@ +; RUN: split-file %s %t +; RUN: opt -S --dxil-translate-metadata %t/only.ll | FileCheck %t/only.ll +; RUN: opt -S --dxil-translate-metadata %t/min.ll | FileCheck %t/min.ll +; RUN: opt -S --dxil-translate-metadata %t/max.ll | FileCheck %t/max.ll +; RUN: opt -S --dxil-translate-metadata %t/pref.ll | FileCheck %t/pref.ll + +; Test that wave size/range metadata is correctly generated with the correct tag + +;--- only.ll + +; CHECK: !dx.entryPoints = !{![[#ENTRY:]]} +; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]} +; CHECK: ![[#PROPS]] = !{{{.*}}i32 11, ![[#WAVE_SIZE:]]{{.*}}} +; CHECK: ![[#WAVE_SIZE]] = !{i32 16} + +target triple = "dxil-unknown-shadermodel6.6-compute" + +define void @main() #0 { +entry: + ret void +} + +attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } + +;--- min.ll + +; CHECK: !dx.entryPoints = !{![[#ENTRY:]]} +; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]} +; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}} +; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 0, i32 0} + +target triple = "dxil-unknown-shadermodel6.8-compute" + +define void @main() #0 { +entry: + ret void +} + +attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } + +;--- max.ll + +; CHECK: !dx.entryPoints = !{![[#ENTRY:]]} +; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]} +; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}} +; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 32, i32 0} + +target triple = "dxil-unknown-shadermodel6.8-compute" + +define void @main() #0 { +entry: + ret void +} + +attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } + +;--- pref.ll + +; CHECK: !dx.entryPoints = !{![[#ENTRY:]]} +; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]} +; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}} +; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 64, i32 32} + +target triple = "dxil-unknown-shadermodel6.8-compute" + +define void @main() #0 { +entry: + ret void +} + +attributes #0 = { "hlsl.wavesize"="16,64,32" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } From ff1838eed00dccf45726eaa45b1de14277e951e3 Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Wed, 29 Oct 2025 14:02:42 -0700 Subject: [PATCH 2/3] small corrections --- llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp index 682847a94c6fb..e1a472fe57642 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -191,7 +191,7 @@ static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP, LLVMContext &Ctx = EP.Entry->getContext(); if (EntryShaderFlags != 0) MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags, - MMDI.ShaderProfile, Ctx)); + EntryShaderFlags, Ctx)); if (EP.Entry != nullptr) { // FIXME: support more props. @@ -227,6 +227,7 @@ static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP, return nullptr; } + // A range is being specified if EP.WaveSizeMax != 0 if (EP.WaveSizeMax && !IsWaveRange) { reportError( M, "Shader model 6.8 or greater is required to specify " From c3ac81cbaf03be52bda859e90640a962d597e59f Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Mon, 3 Nov 2025 12:58:48 -0800 Subject: [PATCH 3/3] update corresponding PSVInfo --- .../lib/Target/DirectX/DXContainerGlobals.cpp | 7 ++++++ .../test/CodeGen/DirectX/wavesize-md-valid.ll | 25 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp index 8ace2d2777c74..f22ee6dd17ac1 100644 --- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp +++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp @@ -284,6 +284,13 @@ void DXContainerGlobals::addPipelineStateValidationInfo( PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX; PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY; PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ; + if (MMI.EntryPropertyVec[0].WaveSizeMin) { + PSV.BaseData.MinimumWaveLaneCount = MMI.EntryPropertyVec[0].WaveSizeMin; + PSV.BaseData.MaximumWaveLaneCount = + MMI.EntryPropertyVec[0].WaveSizeMax + ? MMI.EntryPropertyVec[0].WaveSizeMax + : MMI.EntryPropertyVec[0].WaveSizeMin; + } break; default: break; diff --git a/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll b/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll index 63e8a59eb2648..3ad6c1d034252 100644 --- a/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll +++ b/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll @@ -4,6 +4,11 @@ ; RUN: opt -S --dxil-translate-metadata %t/max.ll | FileCheck %t/max.ll ; RUN: opt -S --dxil-translate-metadata %t/pref.ll | FileCheck %t/pref.ll +; RUN: llc --filetype=obj %t/only.ll -o - | obj2yaml | FileCheck %t/only.ll --check-prefix=OBJ +; RUN: llc --filetype=obj %t/min.ll -o - | obj2yaml | FileCheck %t/min.ll --check-prefix=OBJ +; RUN: llc --filetype=obj %t/max.ll -o - | obj2yaml | FileCheck %t/max.ll --check-prefix=OBJ +; RUN: llc --filetype=obj %t/pref.ll -o - | obj2yaml | FileCheck %t/pref.ll --check-prefix=OBJ + ; Test that wave size/range metadata is correctly generated with the correct tag ;--- only.ll @@ -13,6 +18,11 @@ ; CHECK: ![[#PROPS]] = !{{{.*}}i32 11, ![[#WAVE_SIZE:]]{{.*}}} ; CHECK: ![[#WAVE_SIZE]] = !{i32 16} +; OBJ: - Name: PSV0 +; OBJ: PSVInfo: +; OBJ: MinimumWaveLaneCount: 16 +; OBJ: MaximumWaveLaneCount: 16 + target triple = "dxil-unknown-shadermodel6.6-compute" define void @main() #0 { @@ -29,6 +39,11 @@ attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shade ; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}} ; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 0, i32 0} +; OBJ: - Name: PSV0 +; OBJ: PSVInfo: +; OBJ: MinimumWaveLaneCount: 16 +; OBJ: MaximumWaveLaneCount: 16 + target triple = "dxil-unknown-shadermodel6.8-compute" define void @main() #0 { @@ -45,6 +60,11 @@ attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shade ; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}} ; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 32, i32 0} +; OBJ: - Name: PSV0 +; OBJ: PSVInfo: +; OBJ: MinimumWaveLaneCount: 16 +; OBJ: MaximumWaveLaneCount: 32 + target triple = "dxil-unknown-shadermodel6.8-compute" define void @main() #0 { @@ -61,6 +81,11 @@ attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shad ; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}} ; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 64, i32 32} +; OBJ: - Name: PSV0 +; OBJ: PSVInfo: +; OBJ: MinimumWaveLaneCount: 16 +; OBJ: MaximumWaveLaneCount: 64 + target triple = "dxil-unknown-shadermodel6.8-compute" define void @main() #0 {