Skip to content

Commit b241cc9

Browse files
authored
[ADT] Fix llvm::concat_iterator for ValueT == common_base_class * (#144744)
Fix `llvm::concat_iterator` for the case of `ValueT` being a pointer to a common base class to which the result of dereferencing any iterator in `ItersT` can be casted to.
1 parent def2020 commit b241cc9

File tree

2 files changed

+85
-46
lines changed

2 files changed

+85
-46
lines changed

llvm/include/llvm/ADT/STLExtras.h

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ using is_one_of = std::disjunction<std::is_same<T, Ts>...>;
114114
template <typename T, typename... Ts>
115115
using are_base_of = std::conjunction<std::is_base_of<T, Ts>...>;
116116

117+
/// traits class for checking whether type `T` is same as all other types in
118+
/// `Ts`.
119+
template <typename T = void, typename... Ts>
120+
using all_types_equal = std::conjunction<std::is_same<T, Ts>...>;
121+
template <typename T = void, typename... Ts>
122+
constexpr bool all_types_equal_v = all_types_equal<T, Ts...>::value;
123+
117124
/// Determine if all types in Ts are distinct.
118125
///
119126
/// Useful to statically assert when Ts is intended to describe a non-multi set
@@ -996,13 +1003,17 @@ class concat_iterator
9961003

9971004
static constexpr bool ReturnsByValue =
9981005
!(std::is_reference_v<decltype(*std::declval<IterTs>())> && ...);
999-
1006+
static constexpr bool ReturnsConvertibleType =
1007+
!all_types_equal_v<
1008+
std::remove_cv_t<ValueT>,
1009+
remove_cvref_t<decltype(*std::declval<IterTs>())>...> &&
1010+
(std::is_convertible_v<decltype(*std::declval<IterTs>()), ValueT> && ...);
1011+
1012+
// Cannot return a reference type if a conversion takes place, provided that
1013+
// the result of dereferencing all `IterTs...` is convertible to `ValueT`.
10001014
using reference_type =
1001-
typename std::conditional_t<ReturnsByValue, ValueT, ValueT &>;
1002-
1003-
using handle_type =
1004-
typename std::conditional_t<ReturnsByValue, std::optional<ValueT>,
1005-
ValueT *>;
1015+
std::conditional_t<ReturnsByValue || ReturnsConvertibleType, ValueT,
1016+
ValueT &>;
10061017

10071018
/// We store both the current and end iterators for each concatenated
10081019
/// sequence in a tuple of pairs.
@@ -1013,66 +1024,46 @@ class concat_iterator
10131024
std::tuple<IterTs...> Begins;
10141025
std::tuple<IterTs...> Ends;
10151026

1016-
/// Attempts to increment a specific iterator.
1017-
///
1018-
/// Returns true if it was able to increment the iterator. Returns false if
1019-
/// the iterator is already at the end iterator.
1020-
template <size_t Index> bool incrementHelper() {
1027+
/// Attempts to increment the `Index`-th iterator. If the iterator is already
1028+
/// at end, recurse over iterators in `Others...`.
1029+
template <size_t Index, size_t... Others> void incrementImpl() {
10211030
auto &Begin = std::get<Index>(Begins);
10221031
auto &End = std::get<Index>(Ends);
1023-
if (Begin == End)
1024-
return false;
1025-
1032+
if (Begin == End) {
1033+
if constexpr (sizeof...(Others) != 0)
1034+
return incrementImpl<Others...>();
1035+
llvm_unreachable("Attempted to increment an end concat iterator!");
1036+
}
10261037
++Begin;
1027-
return true;
10281038
}
10291039

10301040
/// Increments the first non-end iterator.
10311041
///
10321042
/// It is an error to call this with all iterators at the end.
10331043
template <size_t... Ns> void increment(std::index_sequence<Ns...>) {
1034-
// Build a sequence of functions to increment each iterator if possible.
1035-
bool (concat_iterator::*IncrementHelperFns[])() = {
1036-
&concat_iterator::incrementHelper<Ns>...};
1037-
1038-
// Loop over them, and stop as soon as we succeed at incrementing one.
1039-
for (auto &IncrementHelperFn : IncrementHelperFns)
1040-
if ((this->*IncrementHelperFn)())
1041-
return;
1042-
1043-
llvm_unreachable("Attempted to increment an end concat iterator!");
1044+
incrementImpl<Ns...>();
10441045
}
10451046

1046-
/// Returns null if the specified iterator is at the end. Otherwise,
1047-
/// dereferences the iterator and returns the address of the resulting
1048-
/// reference.
1049-
template <size_t Index> handle_type getHelper() const {
1047+
/// Dereferences the `Index`-th iterator and returns the resulting reference.
1048+
/// If `Index` is at end, recurse over iterators in `Others...`.
1049+
template <size_t Index, size_t... Others> reference_type getImpl() const {
10501050
auto &Begin = std::get<Index>(Begins);
10511051
auto &End = std::get<Index>(Ends);
1052-
if (Begin == End)
1053-
return {};
1054-
1055-
if constexpr (ReturnsByValue)
1056-
return *Begin;
1057-
else
1058-
return &*Begin;
1052+
if (Begin == End) {
1053+
if constexpr (sizeof...(Others) != 0)
1054+
return getImpl<Others...>();
1055+
llvm_unreachable(
1056+
"Attempted to get a pointer from an end concat iterator!");
1057+
}
1058+
return *Begin;
10591059
}
10601060

10611061
/// Finds the first non-end iterator, dereferences, and returns the resulting
10621062
/// reference.
10631063
///
10641064
/// It is an error to call this with all iterators at the end.
10651065
template <size_t... Ns> reference_type get(std::index_sequence<Ns...>) const {
1066-
// Build a sequence of functions to get from iterator if possible.
1067-
handle_type (concat_iterator::*GetHelperFns[])()
1068-
const = {&concat_iterator::getHelper<Ns>...};
1069-
1070-
// Loop over them, and return the first result we find.
1071-
for (auto &GetHelperFn : GetHelperFns)
1072-
if (auto P = (this->*GetHelperFn)())
1073-
return *P;
1074-
1075-
llvm_unreachable("Attempted to get a pointer from an end concat iterator!");
1066+
return getImpl<Ns...>();
10761067
}
10771068

10781069
public:

llvm/unittests/ADT/STLExtrasTest.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ struct some_struct {
398398
std::string swap_val;
399399
};
400400

401+
struct derives_from_some_struct : some_struct {};
402+
401403
std::vector<int>::const_iterator begin(const some_struct &s) {
402404
return s.data.begin();
403405
}
@@ -500,6 +502,15 @@ TEST(STLExtrasTest, ToVector) {
500502
}
501503
}
502504

505+
TEST(STLExtrasTest, AllTypesEqual) {
506+
static_assert(all_types_equal_v<>);
507+
static_assert(all_types_equal_v<int>);
508+
static_assert(all_types_equal_v<int, int, int>);
509+
510+
static_assert(!all_types_equal_v<int, int, unsigned int>);
511+
static_assert(!all_types_equal_v<int, int, float>);
512+
}
513+
503514
TEST(STLExtrasTest, ConcatRange) {
504515
std::vector<int> Expected = {1, 2, 3, 4, 5, 6, 7, 8};
505516
std::vector<int> Test;
@@ -532,6 +543,43 @@ TEST(STLExtrasTest, ConcatRangeADL) {
532543
EXPECT_THAT(concat<const int>(S0, S1), ElementsAre(1, 2, 3, 4));
533544
}
534545

546+
TEST(STLExtrasTest, ConcatRangePtrToSameClass) {
547+
some_namespace::some_struct S0{};
548+
some_namespace::some_struct S1{};
549+
SmallVector<some_namespace::some_struct *> V0{&S0};
550+
SmallVector<some_namespace::some_struct *> V1{&S1, &S1};
551+
552+
// Dereferencing all iterators yields `some_namespace::some_struct *&`; no
553+
// conversion takes place, `reference_type` is
554+
// `some_namespace::some_struct *&`.
555+
auto C = concat<some_namespace::some_struct *>(V0, V1);
556+
static_assert(
557+
std::is_same_v<decltype(*C.begin()), some_namespace::some_struct *&>);
558+
EXPECT_THAT(C, ElementsAre(&S0, &S1, &S1));
559+
// `reference_type` should still allow container modification.
560+
for (auto &i : C)
561+
if (i == &S0)
562+
i = nullptr;
563+
EXPECT_THAT(C, ElementsAre(nullptr, &S1, &S1));
564+
}
565+
566+
TEST(STLExtrasTest, ConcatRangePtrToDerivedClass) {
567+
some_namespace::some_struct S0{};
568+
some_namespace::derives_from_some_struct S1{};
569+
SmallVector<some_namespace::some_struct *> V0{&S0};
570+
SmallVector<some_namespace::derives_from_some_struct *> V1{&S1, &S1};
571+
572+
// Dereferencing all iterators yields different (but convertible types);
573+
// conversion takes place, `reference_type` is
574+
// `some_namespace::some_struct *`.
575+
auto C = concat<some_namespace::some_struct *>(V0, V1);
576+
static_assert(
577+
std::is_same_v<decltype(*C.begin()), some_namespace::some_struct *>);
578+
EXPECT_THAT(C,
579+
ElementsAre(&S0, static_cast<some_namespace::some_struct *>(&S1),
580+
static_cast<some_namespace::some_struct *>(&S1)));
581+
}
582+
535583
TEST(STLExtrasTest, MakeFirstSecondRangeADL) {
536584
// Make sure that we use the `begin`/`end` functions from `some_namespace`,
537585
// using ADL.

0 commit comments

Comments
 (0)