diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h index 818caccfe1998..76ce74d51b301 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -15,6 +15,7 @@ #define LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H #include "llvm/Support/DXILABI.h" +#include "llvm/Support/raw_ostream.h" #include namespace llvm { @@ -58,6 +59,8 @@ struct Register { struct DescriptorTable { ShaderVisibility Visibility = ShaderVisibility::All; uint32_t NumClauses = 0; // The number of clauses in the table + + void dump(raw_ostream &OS) const; }; static const uint32_t NumDescriptorsUnbounded = 0xffffffff; @@ -86,6 +89,8 @@ struct DescriptorTableClause { break; } } + + void dump(raw_ostream &OS) const; }; // Models RootElement : DescriptorTable | DescriptorTableClause diff --git a/llvm/lib/Frontend/HLSL/CMakeLists.txt b/llvm/lib/Frontend/HLSL/CMakeLists.txt index 07a0c845ceef6..dd987544fe80b 100644 --- a/llvm/lib/Frontend/HLSL/CMakeLists.txt +++ b/llvm/lib/Frontend/HLSL/CMakeLists.txt @@ -1,6 +1,7 @@ add_llvm_component_library(LLVMFrontendHLSL CBuffer.cpp HLSLResource.cpp + HLSLRootSignature.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp new file mode 100644 index 0000000000000..5351239b94b1e --- /dev/null +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp @@ -0,0 +1,149 @@ +//===- HLSLRootSignature.cpp - HLSL Root Signature helper objects ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains helpers for working with HLSL Root Signatures. +/// +//===----------------------------------------------------------------------===// + +#include "llvm/Frontend/HLSL/HLSLRootSignature.h" +#include "llvm/ADT/bit.h" + +namespace llvm { +namespace hlsl { +namespace rootsig { + +static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { + switch (Reg.ViewType) { + case RegisterType::BReg: + OS << "b"; + break; + case RegisterType::TReg: + OS << "t"; + break; + case RegisterType::UReg: + OS << "u"; + break; + case RegisterType::SReg: + OS << "s"; + break; + } + OS << Reg.Number; + return OS; +} + +static raw_ostream &operator<<(raw_ostream &OS, + const ShaderVisibility &Visibility) { + switch (Visibility) { + case ShaderVisibility::All: + OS << "All"; + break; + case ShaderVisibility::Vertex: + OS << "Vertex"; + break; + case ShaderVisibility::Hull: + OS << "Hull"; + break; + case ShaderVisibility::Domain: + OS << "Domain"; + break; + case ShaderVisibility::Geometry: + OS << "Geometry"; + break; + case ShaderVisibility::Pixel: + OS << "Pixel"; + break; + case ShaderVisibility::Amplification: + OS << "Amplification"; + break; + case ShaderVisibility::Mesh: + OS << "Mesh"; + break; + } + + return OS; +} + +void DescriptorTable::dump(raw_ostream &OS) const { + OS << "DescriptorTable(numClauses = " << NumClauses + << ", visibility = " << Visibility << ")"; +} + +static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { + switch (Type) { + case ClauseType::CBuffer: + OS << "CBV"; + break; + case ClauseType::SRV: + OS << "SRV"; + break; + case ClauseType::UAV: + OS << "UAV"; + break; + case ClauseType::Sampler: + OS << "Sampler"; + break; + } + + return OS; +} + +static raw_ostream &operator<<(raw_ostream &OS, + const DescriptorRangeFlags &Flags) { + bool FlagSet = false; + unsigned Remaining = llvm::to_underlying(Flags); + while (Remaining) { + unsigned Bit = 1u << llvm::countr_zero(Remaining); + if (Remaining & Bit) { + if (FlagSet) + OS << " | "; + + switch (static_cast(Bit)) { + case DescriptorRangeFlags::DescriptorsVolatile: + OS << "DescriptorsVolatile"; + break; + case DescriptorRangeFlags::DataVolatile: + OS << "DataVolatile"; + break; + case DescriptorRangeFlags::DataStaticWhileSetAtExecute: + OS << "DataStaticWhileSetAtExecute"; + break; + case DescriptorRangeFlags::DataStatic: + OS << "DataStatic"; + break; + case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks: + OS << "DescriptorsStaticKeepingBufferBoundsChecks"; + break; + default: + OS << "invalid: " << Bit; + break; + } + + FlagSet = true; + } + Remaining &= ~Bit; + } + + if (!FlagSet) + OS << "None"; + + return OS; +} + +void DescriptorTableClause::dump(raw_ostream &OS) const { + OS << Type << "(" << Reg << ", numDescriptors = " << NumDescriptors + << ", space = " << Space << ", offset = "; + if (Offset == DescriptorTableOffsetAppend) + OS << "DescriptorTableOffsetAppend"; + else + OS << Offset; + OS << ", flags = " << Flags << ")"; +} + +} // namespace rootsig +} // namespace hlsl +} // namespace llvm diff --git a/llvm/unittests/Frontend/CMakeLists.txt b/llvm/unittests/Frontend/CMakeLists.txt index 85e113816e3bc..2119642769e3d 100644 --- a/llvm/unittests/Frontend/CMakeLists.txt +++ b/llvm/unittests/Frontend/CMakeLists.txt @@ -1,6 +1,7 @@ set(LLVM_LINK_COMPONENTS Analysis Core + FrontendHLSL FrontendOpenACC FrontendOpenMP Passes @@ -10,6 +11,7 @@ set(LLVM_LINK_COMPONENTS ) add_llvm_unittest(LLVMFrontendTests + HLSLRootSignatureDumpTest.cpp OpenACCTest.cpp OpenMPContextTest.cpp OpenMPIRBuilderTest.cpp diff --git a/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp b/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp new file mode 100644 index 0000000000000..ba1fbfd1f8708 --- /dev/null +++ b/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp @@ -0,0 +1,111 @@ +//===-------- HLSLRootSignatureDumpTest.cpp - RootSignature dump tests ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Frontend/HLSL/HLSLRootSignature.h" +#include "gtest/gtest.h" + +using namespace llvm::hlsl::rootsig; + +namespace { + +TEST(HLSLRootSignatureTest, DescriptorCBVClauseDump) { + DescriptorTableClause Clause; + Clause.Type = ClauseType::CBuffer; + Clause.Reg = {RegisterType::BReg, 0}; + Clause.setDefaultFlags(); + + std::string Out; + llvm::raw_string_ostream OS(Out); + Clause.dump(OS); + OS.flush(); + + std::string Expected = "CBV(b0, numDescriptors = 1, space = 0, " + "offset = DescriptorTableOffsetAppend, " + "flags = DataStaticWhileSetAtExecute)"; + EXPECT_EQ(Out, Expected); +} + +TEST(HLSLRootSignatureTest, DescriptorSRVClauseDump) { + DescriptorTableClause Clause; + Clause.Type = ClauseType::SRV; + Clause.Reg = {RegisterType::TReg, 0}; + Clause.NumDescriptors = 2; + Clause.Space = 42; + Clause.Offset = 3; + Clause.Flags = DescriptorRangeFlags::None; + + std::string Out; + llvm::raw_string_ostream OS(Out); + Clause.dump(OS); + OS.flush(); + + std::string Expected = + "SRV(t0, numDescriptors = 2, space = 42, offset = 3, flags = None)"; + EXPECT_EQ(Out, Expected); +} + +TEST(HLSLRootSignatureTest, DescriptorUAVClauseDump) { + DescriptorTableClause Clause; + Clause.Type = ClauseType::UAV; + Clause.Reg = {RegisterType::UReg, 92374}; + Clause.NumDescriptors = 3298; + Clause.Space = 932847; + Clause.Offset = 1; + Clause.Flags = DescriptorRangeFlags::ValidFlags; + + std::string Out; + llvm::raw_string_ostream OS(Out); + Clause.dump(OS); + OS.flush(); + + std::string Expected = + "UAV(u92374, numDescriptors = 3298, space = 932847, offset = 1, flags = " + "DescriptorsVolatile | " + "DataVolatile | " + "DataStaticWhileSetAtExecute | " + "DataStatic | " + "DescriptorsStaticKeepingBufferBoundsChecks)"; + EXPECT_EQ(Out, Expected); +} + +TEST(HLSLRootSignatureTest, DescriptorSamplerClauseDump) { + DescriptorTableClause Clause; + Clause.Type = ClauseType::Sampler; + Clause.Reg = {RegisterType::SReg, 0}; + Clause.NumDescriptors = 2; + Clause.Space = 42; + Clause.Offset = DescriptorTableOffsetAppend; + Clause.Flags = DescriptorRangeFlags::ValidSamplerFlags; + + std::string Out; + llvm::raw_string_ostream OS(Out); + Clause.dump(OS); + OS.flush(); + + std::string Expected = "Sampler(s0, numDescriptors = 2, space = 42, offset = " + "DescriptorTableOffsetAppend, " + "flags = DescriptorsVolatile)"; + EXPECT_EQ(Out, Expected); +} + +TEST(HLSLRootSignatureTest, DescriptorTableDump) { + DescriptorTable Table; + Table.NumClauses = 4; + Table.Visibility = ShaderVisibility::Geometry; + + std::string Out; + llvm::raw_string_ostream OS(Out); + Table.dump(OS); + OS.flush(); + + std::string Expected = + "DescriptorTable(numClauses = 4, visibility = Geometry)"; + EXPECT_EQ(Out, Expected); +} + +} // namespace