diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h index 87c5615c28ee0..d4b1a9e2ca340 100644 --- a/llvm/include/llvm/Analysis/DXILResource.h +++ b/llvm/include/llvm/Analysis/DXILResource.h @@ -446,6 +446,13 @@ class DXILBindingMap { return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second); } + /// Resolves a resource handle into a vector of ResourceBindingInfos that + /// represent the possible unique creations of the handle. Certain cases are + /// ambiguous so multiple creation instructions may be returned. The resulting + /// ResourceBindingInfo can be used to depuplicate unique handles that + /// reference the same resource + SmallVector findByUse(const Value *Key) const; + const_iterator find(const CallInst *Key) const { auto Pos = CallMap.find(Key); return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second); diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp index 7f28e63cc117d..4ffc9dbebda8d 100644 --- a/llvm/lib/Analysis/DXILResource.cpp +++ b/llvm/lib/Analysis/DXILResource.cpp @@ -770,6 +770,45 @@ void DXILBindingMap::print(raw_ostream &OS, DXILResourceTypeMap &DRTM, } } +SmallVector +DXILBindingMap::findByUse(const Value *Key) const { + if (const PHINode *Phi = dyn_cast(Key)) { + SmallVector Children; + for (const Value *V : Phi->operands()) { + Children.append(findByUse(V)); + } + return Children; + } + + const CallInst *CI = dyn_cast(Key); + if (!CI) + return {}; + + switch (CI->getIntrinsicID()) { + // Found the create, return the binding + case Intrinsic::dx_resource_handlefrombinding: { + const auto *It = find(CI); + assert(It != Infos.end() && "HandleFromBinding must be in resource map"); + return {*It}; + } + default: + break; + } + + // Check if any of the parameters are the resource we are following. If so + // keep searching. If none of them are return an empty list + const Type *UseType = CI->getType(); + SmallVector Children; + for (const Value *V : CI->args()) { + if (V->getType() != UseType) + continue; + + Children.append(findByUse(V)); + } + + return Children; +} + //===----------------------------------------------------------------------===// AnalysisKey DXILResourceTypeAnalysis::Key; diff --git a/llvm/unittests/Target/DirectX/CMakeLists.txt b/llvm/unittests/Target/DirectX/CMakeLists.txt index 626c0d6384268..fd0d5a0dd52c1 100644 --- a/llvm/unittests/Target/DirectX/CMakeLists.txt +++ b/llvm/unittests/Target/DirectX/CMakeLists.txt @@ -8,10 +8,12 @@ set(LLVM_LINK_COMPONENTS Core DirectXCodeGen DirectXPointerTypeAnalysis + Passes Support ) add_llvm_target_unittest(DirectXTests CBufferDataLayoutTests.cpp PointerTypeAnalysisTests.cpp + UniqueResourceFromUseTests.cpp ) diff --git a/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp new file mode 100644 index 0000000000000..f272381c0c250 --- /dev/null +++ b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp @@ -0,0 +1,283 @@ +//===- llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "DirectXTargetMachine.h" +#include "llvm/Analysis/DXILResource.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/SourceMgr.h" + +#include "gtest/gtest.h" + +using namespace llvm; +using namespace llvm::dxil; + +namespace { +class UniqueResourceFromUseTest : public testing::Test { +protected: + PassBuilder *PB; + ModuleAnalysisManager *MAM; + + virtual void SetUp() { + MAM = new ModuleAnalysisManager(); + PB = new PassBuilder(); + PB->registerModuleAnalyses(*MAM); + MAM->registerPass([&] { return DXILResourceTypeAnalysis(); }); + MAM->registerPass([&] { return DXILResourceBindingAnalysis(); }); + } + + virtual void TearDown() { + delete PB; + delete MAM; + } +}; + +TEST_F(UniqueResourceFromUseTest, TestTrivialUse) { + StringRef Assembly = R"( +define void @main() { +entry: + %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false) + call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle) + call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle) + ret void +} + +declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1) +declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle) + )"; + + LLVMContext Context; + SMDiagnostic Error; + auto M = parseAssemblyString(Assembly, Error, Context); + ASSERT_TRUE(M) << "Bad assembly?"; + + const DXILBindingMap &DBM = MAM->getResult(*M); + for (const Function &F : M->functions()) { + if (F.getName() != "a.func") { + continue; + } + + unsigned CalledResources = 0; + + for (const User *U : F.users()) { + const CallInst *CI = cast(U); + const Value *Handle = CI->getArgOperand(0); + const auto Bindings = DBM.findByUse(Handle); + ASSERT_EQ(Bindings.size(), 1u) + << "Handle should resolve into one resource"; + + auto Binding = Bindings[0].getBinding(); + EXPECT_EQ(0u, Binding.RecordID); + EXPECT_EQ(1u, Binding.Space); + EXPECT_EQ(2u, Binding.LowerBound); + EXPECT_EQ(3u, Binding.Size); + + CalledResources++; + } + + EXPECT_EQ(2u, CalledResources) + << "Expected 2 resolved call to create resource"; + } +} + +TEST_F(UniqueResourceFromUseTest, TestIndirectUse) { + StringRef Assembly = R"( +define void @foo() { + %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false) + %handle2 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle) + %handle3 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle2) + %handle4 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle3) + call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle4) + ret void +} + +declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1) +declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle) +declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle) + )"; + + LLVMContext Context; + SMDiagnostic Error; + auto M = parseAssemblyString(Assembly, Error, Context); + ASSERT_TRUE(M) << "Bad assembly?"; + + const DXILBindingMap &DBM = MAM->getResult(*M); + for (const Function &F : M->functions()) { + if (F.getName() != "a.func") { + continue; + } + + unsigned CalledResources = 0; + + for (const User *U : F.users()) { + const CallInst *CI = cast(U); + const Value *Handle = CI->getArgOperand(0); + const auto Bindings = DBM.findByUse(Handle); + ASSERT_EQ(Bindings.size(), 1u) + << "Handle should resolve into one resource"; + + auto Binding = Bindings[0].getBinding(); + EXPECT_EQ(0u, Binding.RecordID); + EXPECT_EQ(1u, Binding.Space); + EXPECT_EQ(2u, Binding.LowerBound); + EXPECT_EQ(3u, Binding.Size); + + CalledResources++; + } + + EXPECT_EQ(1u, CalledResources) + << "Expected 1 resolved call to create resource"; + } +} + +TEST_F(UniqueResourceFromUseTest, TestAmbigousIndirectUse) { + StringRef Assembly = R"( +define void @foo() { + %foo = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false) + %bar = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 2, i32 2, i32 2, i32 2, i1 false) + %baz = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 3, i32 3, i32 3, i32 3, i1 false) + %bat = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false) + %a = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %foo, target("dx.RawBuffer", float, 1, 0) %bar) + %b = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %baz, target("dx.RawBuffer", float, 1, 0) %bat) + %handle = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %a, target("dx.RawBuffer", float, 1, 0) %b) + call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle) + ret void +} + +declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1) +declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle) +declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x, target("dx.RawBuffer", float, 1, 0) %y) + )"; + + LLVMContext Context; + SMDiagnostic Error; + auto M = parseAssemblyString(Assembly, Error, Context); + ASSERT_TRUE(M) << "Bad assembly?"; + + const DXILBindingMap &DBM = MAM->getResult(*M); + for (const Function &F : M->functions()) { + if (F.getName() != "a.func") { + continue; + } + + unsigned CalledResources = 0; + + for (const User *U : F.users()) { + const CallInst *CI = cast(U); + const Value *Handle = CI->getArgOperand(0); + const auto Bindings = DBM.findByUse(Handle); + ASSERT_EQ(Bindings.size(), 4u) + << "Handle should resolve into four resources"; + + auto Binding = Bindings[0].getBinding(); + EXPECT_EQ(0u, Binding.RecordID); + EXPECT_EQ(1u, Binding.Space); + EXPECT_EQ(1u, Binding.LowerBound); + EXPECT_EQ(1u, Binding.Size); + + Binding = Bindings[1].getBinding(); + EXPECT_EQ(1u, Binding.RecordID); + EXPECT_EQ(2u, Binding.Space); + EXPECT_EQ(2u, Binding.LowerBound); + EXPECT_EQ(2u, Binding.Size); + + Binding = Bindings[2].getBinding(); + EXPECT_EQ(2u, Binding.RecordID); + EXPECT_EQ(3u, Binding.Space); + EXPECT_EQ(3u, Binding.LowerBound); + EXPECT_EQ(3u, Binding.Size); + + Binding = Bindings[3].getBinding(); + EXPECT_EQ(3u, Binding.RecordID); + EXPECT_EQ(4u, Binding.Space); + EXPECT_EQ(4u, Binding.LowerBound); + EXPECT_EQ(4u, Binding.Size); + + CalledResources++; + } + + EXPECT_EQ(1u, CalledResources) + << "Expected 1 resolved call to create resource"; + } +} + +TEST_F(UniqueResourceFromUseTest, TestConditionalUse) { + StringRef Assembly = R"( +define void @foo(i32 %n) { +entry: + %x = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false) + %y = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false) + %cond = icmp eq i32 %n, 0 + br i1 %cond, label %bb.true, label %bb.false + +bb.true: + %handle_t = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x) + br label %bb.exit + +bb.false: + %handle_f = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %y) + br label %bb.exit + +bb.exit: + %handle = phi target("dx.RawBuffer", float, 1, 0) [ %handle_t, %bb.true ], [ %handle_f, %bb.false ] + call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle) + ret void +} + +declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1) +declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle) +declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x) + )"; + + LLVMContext Context; + SMDiagnostic Error; + auto M = parseAssemblyString(Assembly, Error, Context); + ASSERT_TRUE(M) << "Bad assembly?"; + + const DXILBindingMap &DBM = MAM->getResult(*M); + for (const Function &F : M->functions()) { + if (F.getName() != "a.func") { + continue; + } + + unsigned CalledResources = 0; + + for (const User *U : F.users()) { + const CallInst *CI = cast(U); + const Value *Handle = CI->getArgOperand(0); + const auto Bindings = DBM.findByUse(Handle); + ASSERT_EQ(Bindings.size(), 2u) + << "Handle should resolve into four resources"; + + auto Binding = Bindings[0].getBinding(); + EXPECT_EQ(0u, Binding.RecordID); + EXPECT_EQ(1u, Binding.Space); + EXPECT_EQ(1u, Binding.LowerBound); + EXPECT_EQ(1u, Binding.Size); + + Binding = Bindings[1].getBinding(); + EXPECT_EQ(1u, Binding.RecordID); + EXPECT_EQ(4u, Binding.Space); + EXPECT_EQ(4u, Binding.LowerBound); + EXPECT_EQ(4u, Binding.Size); + + CalledResources++; + } + + EXPECT_EQ(1u, CalledResources) + << "Expected 1 resolved call to create resource"; + } +} + +} // namespace