diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index 4a91b061dd3b7..5b20d6bd38262 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -1692,6 +1692,28 @@ template auto accumulate(R &&Range, E &&Init) { std::forward(Init)); } +/// Wrapper for std::accumulate with a binary operator. +template +auto accumulate(R &&Range, E &&Init, BinaryOp &&Op) { + return std::accumulate(adl_begin(Range), adl_end(Range), + std::forward(Init), std::forward(Op)); +} + +/// Returns the sum of all values in `Range` with `Init` initial value. +/// The default initial value is 0. +template > +auto sum_of(R &&Range, E Init = E{0}) { + return accumulate(std::forward(Range), std::move(Init)); +} + +/// Returns the product of all values in `Range` with `Init` initial value. +/// The default initial value is 1. +template > +auto product_of(R &&Range, E Init = E{1}) { + return accumulate(std::forward(Range), std::move(Init), + std::multiplies<>{}); +} + /// Provide wrappers to std::for_each which take ranges instead of having to /// pass begin/end explicitly. template diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp index 5020acda95b0b..474699835a7dc 100644 --- a/llvm/unittests/ADT/STLExtrasTest.cpp +++ b/llvm/unittests/ADT/STLExtrasTest.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -1658,6 +1659,54 @@ TEST(STLExtrasTest, Accumulate) { EXPECT_EQ(accumulate(V1, 10), std::accumulate(V1.begin(), V1.end(), 10)); EXPECT_EQ(accumulate(drop_begin(V1), 7), std::accumulate(V1.begin() + 1, V1.end(), 7)); + + EXPECT_EQ(accumulate(V1, 2, std::multiplies<>{}), 240); +} + +TEST(STLExtrasTest, SumOf) { + EXPECT_EQ(sum_of(std::vector()), 0); + EXPECT_EQ(sum_of(std::vector(), 1), 1); + std::vector V1 = {1, 2, 3, 4, 5}; + static_assert(std::is_same_v); + static_assert(std::is_same_v); + EXPECT_EQ(sum_of(V1), 15); + EXPECT_EQ(sum_of(V1, 1), 16); + + std::vector V2 = {1.0f, 2.0f, 4.0f}; + static_assert(std::is_same_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + EXPECT_EQ(sum_of(V2), 7.0f); + EXPECT_EQ(sum_of(V2, 1.0f), 8.0f); + + // Make sure that for a const argument the return value is non-const. + const std::vector V3 = {1.0f, 2.0f}; + static_assert(std::is_same_v); + EXPECT_EQ(sum_of(V3), 3.0f); +} + +TEST(STLExtrasTest, ProductOf) { + EXPECT_EQ(product_of(std::vector()), 1); + EXPECT_EQ(product_of(std::vector(), 0), 0); + EXPECT_EQ(product_of(std::vector(), 1), 1); + std::vector V1 = {1, 2, 3, 4, 5}; + static_assert(std::is_same_v); + static_assert(std::is_same_v); + EXPECT_EQ(product_of(V1), 120); + EXPECT_EQ(product_of(V1, 1), 120); + EXPECT_EQ(product_of(V1, 2), 240); + + std::vector V2 = {1.0f, 2.0f, 4.0f}; + static_assert(std::is_same_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + EXPECT_EQ(product_of(V2), 8.0f); + EXPECT_EQ(product_of(V2, 4.0f), 32.0f); + + // Make sure that for a const argument the return value is non-const. + const std::vector V3 = {1.0f, 2.0f}; + static_assert(std::is_same_v); + EXPECT_EQ(product_of(V3), 2.0f); } struct Foo;