diff --git a/llvm/include/llvm/BinaryFormat/DXContainer.h b/llvm/include/llvm/BinaryFormat/DXContainer.h index bd5a796c0b31c..28905e27837a7 100644 --- a/llvm/include/llvm/BinaryFormat/DXContainer.h +++ b/llvm/include/llvm/BinaryFormat/DXContainer.h @@ -320,7 +320,7 @@ ArrayRef> getResourceKinds(); #define RESOURCE_FLAG(Index, Enum) bool Enum = false; struct ResourceFlags { - ResourceFlags() {}; + ResourceFlags() : Flags(0U) {}; struct FlagsBits { #include "llvm/BinaryFormat/DXContainerConstants.def" }; diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp index 5508af40663b1..c7a130a1f9c8a 100644 --- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp +++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp @@ -186,51 +186,71 @@ void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) { DXILResourceTypeMap &DRTM = getAnalysis().getResourceTypeMap(); - for (const dxil::ResourceBindingInfo &RBI : DBM) { + auto MakeBinding = + [](const dxil::ResourceBindingInfo::ResourceBinding &Binding, + const dxbc::PSV::ResourceType Type, const dxil::ResourceKind Kind, + const dxbc::PSV::ResourceFlags Flags = dxbc::PSV::ResourceFlags()) { + dxbc::PSV::v2::ResourceBindInfo BindInfo; + BindInfo.Type = Type; + BindInfo.LowerBound = Binding.LowerBound; + BindInfo.UpperBound = Binding.LowerBound + Binding.Size - 1; + BindInfo.Space = Binding.Space; + BindInfo.Kind = static_cast(Kind); + BindInfo.Flags = Flags; + return BindInfo; + }; + + for (const dxil::ResourceBindingInfo &RBI : DBM.cbuffers()) { + const dxil::ResourceBindingInfo::ResourceBinding &Binding = + RBI.getBinding(); + PSV.Resources.push_back(MakeBinding(Binding, dxbc::PSV::ResourceType::CBV, + dxil::ResourceKind::CBuffer)); + } + for (const dxil::ResourceBindingInfo &RBI : DBM.samplers()) { + const dxil::ResourceBindingInfo::ResourceBinding &Binding = + RBI.getBinding(); + PSV.Resources.push_back(MakeBinding(Binding, + dxbc::PSV::ResourceType::Sampler, + dxil::ResourceKind::Sampler)); + } + for (const dxil::ResourceBindingInfo &RBI : DBM.srvs()) { const dxil::ResourceBindingInfo::ResourceBinding &Binding = RBI.getBinding(); - dxbc::PSV::v2::ResourceBindInfo BindInfo; - BindInfo.LowerBound = Binding.LowerBound; - BindInfo.UpperBound = Binding.LowerBound + Binding.Size - 1; - BindInfo.Space = Binding.Space; dxil::ResourceTypeInfo &TypeInfo = DRTM[RBI.getHandleTy()]; - dxbc::PSV::ResourceType ResType = dxbc::PSV::ResourceType::Invalid; - bool IsUAV = TypeInfo.getResourceClass() == dxil::ResourceClass::UAV; - switch (TypeInfo.getResourceKind()) { - case dxil::ResourceKind::Sampler: - ResType = dxbc::PSV::ResourceType::Sampler; - break; - case dxil::ResourceKind::CBuffer: - ResType = dxbc::PSV::ResourceType::CBV; - break; - case dxil::ResourceKind::StructuredBuffer: - ResType = IsUAV ? dxbc::PSV::ResourceType::UAVStructured - : dxbc::PSV::ResourceType::SRVStructured; - if (IsUAV && TypeInfo.getUAV().HasCounter) - ResType = dxbc::PSV::ResourceType::UAVStructuredWithCounter; - break; - case dxil::ResourceKind::RTAccelerationStructure: + dxbc::PSV::ResourceType ResType; + if (TypeInfo.isStruct()) + ResType = dxbc::PSV::ResourceType::SRVStructured; + else if (TypeInfo.isTyped()) + ResType = dxbc::PSV::ResourceType::SRVTyped; + else ResType = dxbc::PSV::ResourceType::SRVRaw; - break; - case dxil::ResourceKind::RawBuffer: - ResType = IsUAV ? dxbc::PSV::ResourceType::UAVRaw - : dxbc::PSV::ResourceType::SRVRaw; - break; - default: - ResType = IsUAV ? dxbc::PSV::ResourceType::UAVTyped - : dxbc::PSV::ResourceType::SRVTyped; - break; - } - BindInfo.Type = ResType; - - BindInfo.Kind = - static_cast(TypeInfo.getResourceKind()); + + PSV.Resources.push_back( + MakeBinding(Binding, ResType, TypeInfo.getResourceKind())); + } + for (const dxil::ResourceBindingInfo &RBI : DBM.uavs()) { + const dxil::ResourceBindingInfo::ResourceBinding &Binding = + RBI.getBinding(); + + dxil::ResourceTypeInfo &TypeInfo = DRTM[RBI.getHandleTy()]; + dxbc::PSV::ResourceType ResType; + if (TypeInfo.getUAV().HasCounter) + ResType = dxbc::PSV::ResourceType::UAVStructuredWithCounter; + else if (TypeInfo.isStruct()) + ResType = dxbc::PSV::ResourceType::UAVStructured; + else if (TypeInfo.isTyped()) + ResType = dxbc::PSV::ResourceType::UAVTyped; + else + ResType = dxbc::PSV::ResourceType::UAVRaw; + + dxbc::PSV::ResourceFlags Flags; // TODO: Add support for dxbc::PSV::ResourceFlag::UsedByAtomic64, tracking // with https://github.com/llvm/llvm-project/issues/104392 - BindInfo.Flags.Flags = 0u; + Flags.Flags = 0u; - PSV.Resources.emplace_back(BindInfo); + PSV.Resources.push_back( + MakeBinding(Binding, ResType, TypeInfo.getResourceKind(), Flags)); } } diff --git a/llvm/test/CodeGen/DirectX/ContainerData/PSVResources-order.ll b/llvm/test/CodeGen/DirectX/ContainerData/PSVResources-order.ll new file mode 100644 index 0000000000000..734149eec598e --- /dev/null +++ b/llvm/test/CodeGen/DirectX/ContainerData/PSVResources-order.ll @@ -0,0 +1,26 @@ +; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s + +; Check that resources are emitted to the object in the order that matches what +; the DXIL validator expects: CBuffers, Samplers, SRVs, and then UAVs. + +; CHECK: Resources: +; CHECK: - Type: CBV +; TODO: - Type: Sampler +; CHECK: - Type: SRVRaw +; CHECK: - Type: UAVTyped + +target triple = "dxil-unknown-shadermodel6.0-compute" + +define void @main() #0 { + %uav0 = call target("dx.TypedBuffer", i32, 1, 0, 1) + @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i32_1_0t( + i32 2, i32 7, i32 1, i32 0, i1 false) + %srv0 = call target("dx.RawBuffer", i8, 0, 0) + @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_i8_0_0t( + i32 1, i32 8, i32 1, i32 0, i1 false) + %cbuf = call target("dx.CBuffer", target("dx.Layout", {float}, 4, 0)) + @llvm.dx.resource.handlefrombinding(i32 3, i32 2, i32 1, i32 0, i1 false) + ret void +} + +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } diff --git a/llvm/test/CodeGen/DirectX/ContainerData/PSVResources.ll b/llvm/test/CodeGen/DirectX/ContainerData/PSVResources.ll index ce67812c3988f..cea8ba2f067c1 100644 --- a/llvm/test/CodeGen/DirectX/ContainerData/PSVResources.ll +++ b/llvm/test/CodeGen/DirectX/ContainerData/PSVResources.ll @@ -6,6 +6,17 @@ target triple = "dxil-unknown-shadermodel6.0-compute" define void @main() #0 { + ; cbuffer : register(b2, space3) { float x; } +; CHECK: - Type: CBV +; CHECK: Space: 3 +; CHECK: LowerBound: 2 +; CHECK: UpperBound: 2 +; CHECK: Kind: CBuffer +; CHECK: Flags: +; CHECK: UsedByAtomic64: false + %cbuf = call target("dx.CBuffer", target("dx.Layout", {float}, 4, 0)) + @llvm.dx.resource.handlefrombinding(i32 3, i32 2, i32 1, i32 0, i1 false) + ; ByteAddressBuffer Buf : register(t8, space1) ; CHECK: - Type: SRVRaw ; CHECK: Space: 1