diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 9598622f20225..f899b8b67affe 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -519,8 +519,11 @@ bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const { } size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const { - size_t result = 0; + auto It = BlockToOrder.find(BB); + if (It != BlockToOrder.end()) + return It->second.Rank; + size_t result = 0; for (BasicBlock *P : predecessors(BB)) { // Ignore back-edges. if (DT.dominates(BB, P)) @@ -552,15 +555,20 @@ size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) { ToVisit.push(BB); Queued.insert(BB); + size_t QueueIndex = 0; while (ToVisit.size() != 0) { BasicBlock *BB = ToVisit.front(); ToVisit.pop(); if (!CanBeVisited(BB)) { ToVisit.push(BB); + assert(QueueIndex < ToVisit.size() && + "No valid candidate in the queue. Is the graph reducible?"); + QueueIndex++; continue; } + QueueIndex = 0; size_t Rank = GetNodeRank(BB); OrderInfo Info = {Rank, BlockToOrder.size()}; BlockToOrder.emplace(BB, Info); diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index da0e8769cac1b..d218dbd850dc7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -41,6 +41,8 @@ class SPIRVSubtarget; // ignores back-edges. The cycle is visited from the entry in the same // topological-like ordering. // +// Note: this visitor REQUIRES a reducible graph. +// // This means once we visit a node, we know all the possible ancestors have been // visited. // @@ -84,10 +86,11 @@ class PartialOrderingVisitor { // Visits |BB| with the current rank being |Rank|. size_t visit(BasicBlock *BB, size_t Rank); - size_t GetNodeRank(BasicBlock *BB) const; bool CanBeVisited(BasicBlock *BB) const; public: + size_t GetNodeRank(BasicBlock *BB) const; + // Build the visitor to operate on the function F. PartialOrderingVisitor(Function &F); diff --git a/llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.simple2.ll b/llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.simple2.ll index d7b453aac56bc..58f0f1c6bf053 100644 --- a/llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.simple2.ll +++ b/llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.simple2.ll @@ -19,10 +19,10 @@ ; b += 2; ; break; ; case 3: -; { -; b += 3; -; break; -; } +; { +; b += 3; +; break; +; } ; case t: ; b += t; ; case 4: @@ -30,10 +30,10 @@ ; b += 5; ; break; ; case 6: { -; case 7: -; break;} +; case 7: +; break;} ; default: -; break; +; break; ; } ; ; return a + b + c; diff --git a/llvm/unittests/Target/SPIRV/CMakeLists.txt b/llvm/unittests/Target/SPIRV/CMakeLists.txt index 2af36225c5f20..af7d9395605d0 100644 --- a/llvm/unittests/Target/SPIRV/CMakeLists.txt +++ b/llvm/unittests/Target/SPIRV/CMakeLists.txt @@ -16,5 +16,6 @@ set(LLVM_LINK_COMPONENTS add_llvm_target_unittest(SPIRVTests SPIRVConvergenceRegionAnalysisTests.cpp SPIRVSortBlocksTests.cpp + SPIRVPartialOrderingVisitorTests.cpp SPIRVAPITest.cpp ) diff --git a/llvm/unittests/Target/SPIRV/SPIRVPartialOrderingVisitorTests.cpp b/llvm/unittests/Target/SPIRV/SPIRVPartialOrderingVisitorTests.cpp new file mode 100644 index 0000000000000..95d6c3f5a2940 --- /dev/null +++ b/llvm/unittests/Target/SPIRV/SPIRVPartialOrderingVisitorTests.cpp @@ -0,0 +1,342 @@ +//===- SPIRVPartialOrderingVisitorTests.cpp ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "SPIRVUtils.h" +#include "llvm/Analysis/DominanceFrontier.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassInstrumentation.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/TypedPointerType.h" +#include "llvm/Support/SourceMgr.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include + +using namespace llvm; +using namespace llvm::SPIRV; + +class SPIRVPartialOrderingVisitorTest : public testing::Test { +protected: + void TearDown() override { M.reset(); } + + void run(StringRef Assembly) { + assert(M == nullptr && + "Calling runAnalysis multiple times is unsafe. See getAnalysis()."); + + SMDiagnostic Error; + M = parseAssemblyString(Assembly, Error, Context); + assert(M && "Bad assembly. Bad test?"); + + llvm::Function *F = M->getFunction("main"); + Visitor = std::make_unique(*F); + } + + void + checkBasicBlockRank(std::vector> &&Expected) { + llvm::Function *F = M->getFunction("main"); + auto It = Expected.begin(); + Visitor->partialOrderVisit(*F->begin(), [&](BasicBlock *BB) { + const auto &[Name, Rank] = *It; + EXPECT_TRUE(It != Expected.end()) + << "Unexpected block \"" << BB->getName() << " visited."; + EXPECT_TRUE(BB->getName() == Name) + << "Error: expected block \"" << Name << "\" got \"" << BB->getName() + << "\""; + EXPECT_EQ(Rank, Visitor->GetNodeRank(BB)) + << "Bad rank for BB \"" << BB->getName() << "\""; + It++; + return true; + }); + ASSERT_TRUE(It == Expected.end()) + << "Expected block \"" << It->first + << "\" but reached the end of the function instead."; + } + +protected: + LLVMContext Context; + std::unique_ptr M; + std::unique_ptr Visitor; +}; + +TEST_F(SPIRVPartialOrderingVisitorTest, EmptyFunction) { + StringRef Assembly = R"( + define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { + ret void + } + )"; + + run(Assembly); + checkBasicBlockRank({{"", 0}}); +} + +TEST_F(SPIRVPartialOrderingVisitorTest, BasicBlockSwap) { + StringRef Assembly = R"( + define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { + entry: + br label %middle + exit: + ret void + middle: + br label %exit + } + )"; + + run(Assembly); + checkBasicBlockRank({{"entry", 0}, {"middle", 1}, {"exit", 2}}); +} + +// Skip condition: +// +-> A -+ +// entry -+ +-> C +// +------+ +TEST_F(SPIRVPartialOrderingVisitorTest, SkipCondition) { + StringRef Assembly = R"( + define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { + entry: + %1 = icmp ne i32 0, 0 + br i1 %1, label %c, label %a + c: + ret void + a: + br label %c + } + )"; + + run(Assembly); + checkBasicBlockRank({{"entry", 0}, {"a", 1}, {"c", 2}}); +} + +// Simple loop: +// entry -> header <-----------------+ +// | `-> body -> continue -+ +// `-> end +TEST_F(SPIRVPartialOrderingVisitorTest, LoopOrdering) { + StringRef Assembly = R"( + define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { + entry: + %1 = icmp ne i32 0, 0 + br label %header + end: + ret void + body: + br label %continue + continue: + br label %header + header: + br i1 %1, label %body, label %end + } + )"; + + run(Assembly); + checkBasicBlockRank( + {{"entry", 0}, {"header", 1}, {"body", 2}, {"continue", 3}, {"end", 4}}); +} + +// Diamond condition: +// +-> A -+ +// entry -+ +-> C +// +-> B -+ +// +// A and B order can be flipped with no effect, but it must be remain +// deterministic/stable. +TEST_F(SPIRVPartialOrderingVisitorTest, DiamondCondition) { + StringRef Assembly = R"( + define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { + entry: + %1 = icmp ne i32 0, 0 + br i1 %1, label %a, label %b + c: + ret void + b: + br label %c + a: + br label %c + } + )"; + + run(Assembly); + checkBasicBlockRank({{"entry", 0}, {"a", 1}, {"b", 1}, {"c", 2}}); +} + +// Crossing conditions: +// +------+ +-> C -+ +// +-> A -+ | | | +// entry -+ +--_|_-+ +-> E +// +-> B -+ | | +// +------+----> D -+ +// +// A & B have the same rank. +// C & D have the same rank, but are after A & B. +// E if the last block. +TEST_F(SPIRVPartialOrderingVisitorTest, CrossingCondition) { + StringRef Assembly = R"( + define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { + entry: + %1 = icmp ne i32 0, 0 + br i1 %1, label %a, label %b + e: + ret void + c: + br label %e + b: + br i1 %1, label %d, label %c + d: + br label %e + a: + br i1 %1, label %c, label %d + } + )"; + + run(Assembly); + checkBasicBlockRank( + {{"entry", 0}, {"a", 1}, {"b", 1}, {"c", 2}, {"d", 2}, {"e", 3}}); +} + +TEST_F(SPIRVPartialOrderingVisitorTest, LoopDiamond) { + StringRef Assembly = R"( + define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { + entry: + %1 = icmp ne i32 0, 0 + br label %header + header: + br i1 %1, label %body, label %end + body: + br i1 %1, label %inside_a, label %break + inside_a: + br label %inside_b + inside_b: + br i1 %1, label %inside_c, label %inside_d + inside_c: + br label %continue + inside_d: + br label %continue + break: + br label %end + continue: + br label %header + end: + ret void + } + )"; + + run(Assembly); + checkBasicBlockRank({{"entry", 0}, + {"header", 1}, + {"body", 2}, + {"inside_a", 3}, + {"inside_b", 4}, + {"inside_c", 5}, + {"inside_d", 5}, + {"continue", 6}, + {"break", 7}, + {"end", 8}}); +} + +TEST_F(SPIRVPartialOrderingVisitorTest, LoopNested) { + StringRef Assembly = R"( + define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { + entry: + %1 = icmp ne i32 0, 0 + br label %a + a: + br i1 %1, label %h, label %b + b: + br label %c + c: + br i1 %1, label %d, label %e + d: + br label %g + e: + br label %f + f: + br label %c + g: + br label %a + h: + ret void + } + )"; + + run(Assembly); + checkBasicBlockRank({{"entry", 0}, + {"a", 1}, + {"b", 2}, + {"c", 3}, + {"e", 4}, + {"f", 5}, + {"d", 6}, + {"g", 7}, + {"h", 8}}); +} + +TEST_F(SPIRVPartialOrderingVisitorTest, IfNested) { + StringRef Assembly = R"( + define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { + entry: + br i1 true, label %a, label %d + a: + br i1 true, label %b, label %c + b: + br label %c + c: + br label %j + d: + br i1 true, label %e, label %f + e: + br label %i + f: + br i1 true, label %g, label %h + g: + br label %h + h: + br label %i + i: + br label %j + j: + ret void + } + )"; + run(Assembly); + checkBasicBlockRank({{"entry", 0}, + {"a", 1}, + {"d", 1}, + {"b", 2}, + {"e", 2}, + {"f", 2}, + {"c", 3}, + {"g", 3}, + {"h", 4}, + {"i", 5}, + {"j", 6}}); +} + +TEST_F(SPIRVPartialOrderingVisitorTest, CheckDeathIrreducible) { + StringRef Assembly = R"( + define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { + entry: + %1 = icmp ne i32 0, 0 + br label %a + b: + br i1 %1, label %a, label %c + c: + br label %b + a: + br i1 %1, label %b, label %c + } + )"; + + ASSERT_DEATH( + { run(Assembly); }, + "No valid candidate in the queue. Is the graph reducible?"); +}