diff --git a/llvm/include/llvm/Support/Casting.h b/llvm/include/llvm/Support/Casting.h index 6f6df2e9703ea..a6435a2562a2b 100644 --- a/llvm/include/llvm/Support/Casting.h +++ b/llvm/include/llvm/Support/Casting.h @@ -816,6 +816,42 @@ template struct IsaAndPresentCheckPredicate { return isa_and_present(Val); } }; + +//===----------------------------------------------------------------------===// +// Casting Function Objects +//===----------------------------------------------------------------------===// + +/// Usable in generic algorithms like map_range +template struct StaticCastFunc { + template decltype(auto) operator()(T &&Val) const { + return static_cast(Val); + } +}; + +template struct DynCastFunc { + template decltype(auto) operator()(T &&Val) const { + return dyn_cast(Val); + } +}; + +template struct CastFunc { + template decltype(auto) operator()(T &&Val) const { + return cast(Val); + } +}; + +template struct CastIfPresentFunc { + template decltype(auto) operator()(T &&Val) const { + return cast_if_present(Val); + } +}; + +template struct DynCastIfPresentFunc { + template decltype(auto) operator()(T &&Val) const { + return dyn_cast_if_present(Val); + } +}; + } // namespace detail /// Function object wrapper for the `llvm::isa` type check. The function call @@ -841,6 +877,20 @@ template inline constexpr detail::IsaAndPresentCheckPredicate IsaAndPresentPred{}; +/// Function objects corresponding to the Cast types defined above. +template +inline constexpr detail::StaticCastFunc StaticCastTo{}; + +template inline constexpr detail::CastFunc CastTo{}; + +template +inline constexpr detail::CastIfPresentFunc CastIfPresentTo{}; + +template +inline constexpr detail::DynCastIfPresentFunc DynCastIfPresentTo{}; + +template inline constexpr detail::DynCastFunc DynCastTo{}; + } // end namespace llvm #endif // LLVM_SUPPORT_CASTING_H diff --git a/llvm/unittests/Support/Casting.cpp b/llvm/unittests/Support/Casting.cpp index 790675083614b..0df8b9fcab452 100644 --- a/llvm/unittests/Support/Casting.cpp +++ b/llvm/unittests/Support/Casting.cpp @@ -561,6 +561,47 @@ TEST(CastingTest, assertion_check_unique_ptr) { << "Invalid cast of const ref did not cause an abort()"; } +TEST(Casting, StaticCastPredicate) { + uint32_t Value = 1; + + static_assert( + std::is_same_v(Value)), uint64_t>); +} + +TEST(Casting, LLVMRTTIPredicates) { + struct Base { + enum Kind { BK_Base, BK_Derived }; + const Kind K; + Base(Kind K = BK_Base) : K(K) {} + Kind getKind() const { return K; } + virtual ~Base() = default; + }; + + struct Derived : Base { + Derived() : Base(BK_Derived) {} + static bool classof(const Base *B) { return B->getKind() == BK_Derived; } + bool Field = false; + }; + + Base B; + Derived D; + Base *BD = &D; + Base *Null = nullptr; + + // Pointers. + EXPECT_EQ(DynCastTo(BD), &D); + EXPECT_EQ(CastTo(BD), &D); + EXPECT_EQ(DynCastTo(&B), nullptr); + EXPECT_EQ(CastIfPresentTo(BD), &D); + EXPECT_EQ(CastIfPresentTo(Null), nullptr); + EXPECT_EQ(DynCastIfPresentTo(BD), &D); + EXPECT_EQ(DynCastIfPresentTo(Null), nullptr); + + Base &R = D; + CastTo(R).Field = true; + EXPECT_TRUE(D.Field); +} + } // end namespace assertion_checks #endif } // end namespace