Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ struct EntryProperties {
unsigned NumThreadsX{0}; // X component
unsigned NumThreadsY{0}; // Y component
unsigned NumThreadsZ{0}; // Z component
unsigned WaveSizeMin{0}; // Minimum component
unsigned WaveSizeMax{0}; // Maximum component
unsigned WaveSizePref{0}; // Preferred component

EntryProperties(const Function *Fn = nullptr) : Entry(Fn) {};
};
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Analysis/DXILMetadataAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
Success = llvm::to_integer(NumThreadsVec[2], EFP.NumThreadsZ, 10);
assert(Success && "Failed to parse Z component of numthreads");
}
// Get wavesize attribute value, if one exists
StringRef WaveSizeStr =
F.getFnAttribute("hlsl.wavesize").getValueAsString();
if (!WaveSizeStr.empty()) {
SmallVector<StringRef> WaveSizeVec;
WaveSizeStr.split(WaveSizeVec, ',');
assert(WaveSizeVec.size() == 3 && "Invalid wavesize specified");
// Read in the three component values of numthreads
[[maybe_unused]] bool Success =
llvm::to_integer(WaveSizeVec[0], EFP.WaveSizeMin, 10);
assert(Success && "Failed to parse Min component of wavesize");
Success = llvm::to_integer(WaveSizeVec[1], EFP.WaveSizeMax, 10);
assert(Success && "Failed to parse Max component of wavesize");
Success = llvm::to_integer(WaveSizeVec[2], EFP.WaveSizePref, 10);
assert(Success && "Failed to parse Preferred component of wavesize");
}
MMDAI.EntryPropertyVec.push_back(EFP);
}
return MMDAI;
Expand Down
66 changes: 53 additions & 13 deletions llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ enum class EntryPropsTag {
ASStateTag,
WaveSize,
EntryRootSig,
WaveRange = 23,
};

} // namespace
Expand Down Expand Up @@ -177,14 +178,15 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
case EntryPropsTag::ASStateTag:
case EntryPropsTag::WaveSize:
case EntryPropsTag::EntryRootSig:
case EntryPropsTag::WaveRange:
llvm_unreachable("NYI: Unhandled entry property tag");
}
return MDVals;
}

static MDTuple *
getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
const Triple::EnvironmentType ShaderProfile) {
static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP,
uint64_t EntryShaderFlags,
const ModuleMetadataInfo &MMDI) {
SmallVector<Metadata *> MDVals;
LLVMContext &Ctx = EP.Entry->getContext();
if (EntryShaderFlags != 0)
Expand All @@ -195,12 +197,13 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
// FIXME: support more props.
// See https://github.com/llvm/llvm-project/issues/57948.
// Add shader kind for lib entries.
if (ShaderProfile == Triple::EnvironmentType::Library &&
if (MMDI.ShaderProfile == Triple::EnvironmentType::Library &&
EP.ShaderStage != Triple::EnvironmentType::Library)
MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
getShaderStage(EP.ShaderStage), Ctx));

if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
// Handle mandatory "hlsl.numthreads"
MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
Expand All @@ -210,8 +213,48 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));

// Handle optional "hlsl.wavesize". The fields are optionally represented
// if they are non-zero.
if (EP.WaveSizeMin != 0) {
bool IsWaveRange = VersionTuple(6, 8) <= MMDI.ShaderModelVersion;
bool IsWaveSize =
!IsWaveRange && VersionTuple(6, 6) <= MMDI.ShaderModelVersion;

if (!IsWaveRange && !IsWaveSize) {
reportError(M, "Shader model 6.6 or greater is required to specify "
"the \"hlsl.wavesize\" function attribute");
return nullptr;
}

// A range is being specified if EP.WaveSizeMax != 0
if (EP.WaveSizeMax && !IsWaveRange) {
reportError(
M, "Shader model 6.8 or greater is required to specify "
"wave size range values of the \"hlsl.wavesize\" function "
"attribute");
return nullptr;
}

EntryPropsTag Tag =
IsWaveSize ? EntryPropsTag::WaveSize : EntryPropsTag::WaveRange;
MDVals.emplace_back(ConstantAsMetadata::get(
ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));

SmallVector<Metadata *> WaveSizeVals = {ConstantAsMetadata::get(
ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMin))};
if (IsWaveRange) {
WaveSizeVals.push_back(ConstantAsMetadata::get(
ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMax)));
WaveSizeVals.push_back(ConstantAsMetadata::get(
ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizePref)));
}

MDVals.emplace_back(MDNode::get(Ctx, WaveSizeVals));
}
}
}

if (MDVals.empty())
return nullptr;
return MDNode::get(Ctx, MDVals);
Expand All @@ -236,12 +279,11 @@ static MDTuple *constructEntryMetadata(const Function *EntryFn,
return MDNode::get(Ctx, MDVals);
}

static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
MDNode *MDResources,
static MDTuple *emitEntryMD(Module &M, const EntryProperties &EP,
MDTuple *Signatures, MDNode *MDResources,
const uint64_t EntryShaderFlags,
const Triple::EnvironmentType ShaderProfile) {
MDTuple *Properties =
getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
const ModuleMetadataInfo &MMDI) {
MDTuple *Properties = getEntryPropAsMetadata(M, EP, EntryShaderFlags, MMDI);
return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
EP.Entry->getContext());
}
Expand Down Expand Up @@ -523,10 +565,8 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
"'"));
}

EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
EntryShaderFlags,
MMDI.ShaderProfile));
EntryFnMDNodes.emplace_back(emitEntryMD(
M, EntryProp, Signatures, ResourceMD, EntryShaderFlags, MMDI));
}

NamedMDNode *EntryPointsNamedMD =
Expand Down
31 changes: 31 additions & 0 deletions llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
; RUN: split-file %s %t
; RUN: not opt -S --dxil-translate-metadata %t/low-sm.ll 2>&1 | FileCheck %t/low-sm.ll
; RUN: not opt -S --dxil-translate-metadata %t/low-sm-for-range.ll 2>&1 | FileCheck %t/low-sm-for-range.ll

; Test that wavesize metadata is only allowed on applicable shader model versions

;--- low-sm.ll

; CHECK: Shader model 6.6 or greater is required to specify the "hlsl.wavesize" function attribute

target triple = "dxil-unknown-shadermodel6.5-compute"

define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

;--- low-sm-for-range.ll

; CHECK: Shader model 6.8 or greater is required to specify wave size range values of the "hlsl.wavesize" function attribute

target triple = "dxil-unknown-shadermodel6.7-compute"

define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
71 changes: 71 additions & 0 deletions llvm/test/CodeGen/DirectX/wavesize-md-valid.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
; RUN: split-file %s %t
; RUN: opt -S --dxil-translate-metadata %t/only.ll | FileCheck %t/only.ll
; RUN: opt -S --dxil-translate-metadata %t/min.ll | FileCheck %t/min.ll
; RUN: opt -S --dxil-translate-metadata %t/max.ll | FileCheck %t/max.ll
; RUN: opt -S --dxil-translate-metadata %t/pref.ll | FileCheck %t/pref.ll

; Test that wave size/range metadata is correctly generated with the correct tag

;--- only.ll

; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
; CHECK: ![[#PROPS]] = !{{{.*}}i32 11, ![[#WAVE_SIZE:]]{{.*}}}
; CHECK: ![[#WAVE_SIZE]] = !{i32 16}

target triple = "dxil-unknown-shadermodel6.6-compute"

define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

;--- min.ll

; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 0, i32 0}

target triple = "dxil-unknown-shadermodel6.8-compute"

define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

;--- max.ll

; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 32, i32 0}

target triple = "dxil-unknown-shadermodel6.8-compute"

define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

;--- pref.ll

; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 64, i32 32}

target triple = "dxil-unknown-shadermodel6.8-compute"

define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.wavesize"="16,64,32" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }