diff --git a/llvm/include/llvm/Support/TrailingObjects.h b/llvm/include/llvm/Support/TrailingObjects.h index 07cf08df45a6a..f25f2311a81a4 100644 --- a/llvm/include/llvm/Support/TrailingObjects.h +++ b/llvm/include/llvm/Support/TrailingObjects.h @@ -46,6 +46,7 @@ #ifndef LLVM_SUPPORT_TRAILINGOBJECTS_H #define LLVM_SUPPORT_TRAILINGOBJECTS_H +#include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Alignment.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/MathExtras.h" @@ -301,6 +302,41 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl< static_cast(this), TrailingObjectsBase::OverloadToken()); } + // getTrailingObjects() specialization for a single trailing type. + using FirstTrailingType = + typename std::tuple_element_t<0, std::tuple>; + + const FirstTrailingType *getTrailingObjects() const { + static_assert(sizeof...(TrailingTys) == 1, + "Can use non-templated getTrailingObjects() only when there " + "is a single trailing type"); + return getTrailingObjects(); + } + + FirstTrailingType *getTrailingObjects() { + static_assert(sizeof...(TrailingTys) == 1, + "Can use non-templated getTrailingObjects() only when there " + "is a single trailing type"); + return getTrailingObjects(); + } + + // Functions that return the trailing objects as ArrayRefs. + template MutableArrayRef getTrailingObjects(size_t N) { + return MutableArrayRef(getTrailingObjects(), N); + } + + template ArrayRef getTrailingObjects(size_t N) const { + return ArrayRef(getTrailingObjects(), N); + } + + MutableArrayRef getTrailingObjects(size_t N) { + return MutableArrayRef(getTrailingObjects(), N); + } + + ArrayRef getTrailingObjects(size_t N) const { + return ArrayRef(getTrailingObjects(), N); + } + /// Returns the size of the trailing data, if an object were /// allocated with the given counts (The counts are in the same order /// as the template arguments). This does not include the size of the diff --git a/llvm/unittests/Support/TrailingObjectsTest.cpp b/llvm/unittests/Support/TrailingObjectsTest.cpp index e36979e75d7f7..6f9d7bda7fe5a 100644 --- a/llvm/unittests/Support/TrailingObjectsTest.cpp +++ b/llvm/unittests/Support/TrailingObjectsTest.cpp @@ -26,7 +26,9 @@ class Class1 final : protected TrailingObjects { size_t numTrailingObjects(OverloadToken) const { return NumShorts; } Class1(ArrayRef ShortArray) : NumShorts(ShortArray.size()) { - llvm::copy(ShortArray, getTrailingObjects()); + // This tests the non-templated getTrailingObjects() that returns a pointer + // when using a single trailing type. + llvm::copy(ShortArray, getTrailingObjects()); } public: @@ -36,7 +38,8 @@ class Class1 final : protected TrailingObjects { } void operator delete(void *Ptr) { ::operator delete(Ptr); } - short get(unsigned Num) const { return getTrailingObjects()[Num]; } + // This indexes into the ArrayRef<> returned by `getTrailingObjects`. + short get(unsigned Num) const { return getTrailingObjects(NumShorts)[Num]; } unsigned numShorts() const { return NumShorts; } @@ -128,6 +131,9 @@ TEST(TrailingObjects, OneArg) { EXPECT_EQ(C->getTrailingObjects(), reinterpret_cast(C + 1)); EXPECT_EQ(C->get(0), 1); EXPECT_EQ(C->get(2), 3); + + EXPECT_EQ(C->getTrailingObjects(), C->getTrailingObjects()); + delete C; }