Skip to content

Commit 454ef02

Browse files
authored
[ADT] Add sum_of and product_of accumulate wrappers (#162129)
Also extend the `accumulate` wrapper to accept a binary operator. The goal is to the most common usage of `std::accumulate` across the codebase -- calculating either the sum of or the product of all values.
1 parent 64574d3 commit 454ef02

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

llvm/include/llvm/ADT/STLExtras.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1692,6 +1692,28 @@ template <typename R, typename E> auto accumulate(R &&Range, E &&Init) {
16921692
std::forward<E>(Init));
16931693
}
16941694

1695+
/// Wrapper for std::accumulate with a binary operator.
1696+
template <typename R, typename E, typename BinaryOp>
1697+
auto accumulate(R &&Range, E &&Init, BinaryOp &&Op) {
1698+
return std::accumulate(adl_begin(Range), adl_end(Range),
1699+
std::forward<E>(Init), std::forward<BinaryOp>(Op));
1700+
}
1701+
1702+
/// Returns the sum of all values in `Range` with `Init` initial value.
1703+
/// The default initial value is 0.
1704+
template <typename R, typename E = detail::ValueOfRange<R>>
1705+
auto sum_of(R &&Range, E Init = E{0}) {
1706+
return accumulate(std::forward<R>(Range), std::move(Init));
1707+
}
1708+
1709+
/// Returns the product of all values in `Range` with `Init` initial value.
1710+
/// The default initial value is 1.
1711+
template <typename R, typename E = detail::ValueOfRange<R>>
1712+
auto product_of(R &&Range, E Init = E{1}) {
1713+
return accumulate(std::forward<R>(Range), std::move(Init),
1714+
std::multiplies<>{});
1715+
}
1716+
16951717
/// Provide wrappers to std::for_each which take ranges instead of having to
16961718
/// pass begin/end explicitly.
16971719
template <typename R, typename UnaryFunction>

llvm/unittests/ADT/STLExtrasTest.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <array>
1515
#include <climits>
1616
#include <cstddef>
17+
#include <functional>
1718
#include <initializer_list>
1819
#include <iterator>
1920
#include <list>
@@ -1658,6 +1659,54 @@ TEST(STLExtrasTest, Accumulate) {
16581659
EXPECT_EQ(accumulate(V1, 10), std::accumulate(V1.begin(), V1.end(), 10));
16591660
EXPECT_EQ(accumulate(drop_begin(V1), 7),
16601661
std::accumulate(V1.begin() + 1, V1.end(), 7));
1662+
1663+
EXPECT_EQ(accumulate(V1, 2, std::multiplies<>{}), 240);
1664+
}
1665+
1666+
TEST(STLExtrasTest, SumOf) {
1667+
EXPECT_EQ(sum_of(std::vector<int>()), 0);
1668+
EXPECT_EQ(sum_of(std::vector<int>(), 1), 1);
1669+
std::vector<int> V1 = {1, 2, 3, 4, 5};
1670+
static_assert(std::is_same_v<decltype(sum_of(V1)), int>);
1671+
static_assert(std::is_same_v<decltype(sum_of(V1, 1)), int>);
1672+
EXPECT_EQ(sum_of(V1), 15);
1673+
EXPECT_EQ(sum_of(V1, 1), 16);
1674+
1675+
std::vector<float> V2 = {1.0f, 2.0f, 4.0f};
1676+
static_assert(std::is_same_v<decltype(sum_of(V2)), float>);
1677+
static_assert(std::is_same_v<decltype(sum_of(V2), 1.0f), float>);
1678+
static_assert(std::is_same_v<decltype(sum_of(V2), 1.0), double>);
1679+
EXPECT_EQ(sum_of(V2), 7.0f);
1680+
EXPECT_EQ(sum_of(V2, 1.0f), 8.0f);
1681+
1682+
// Make sure that for a const argument the return value is non-const.
1683+
const std::vector<float> V3 = {1.0f, 2.0f};
1684+
static_assert(std::is_same_v<decltype(sum_of(V3)), float>);
1685+
EXPECT_EQ(sum_of(V3), 3.0f);
1686+
}
1687+
1688+
TEST(STLExtrasTest, ProductOf) {
1689+
EXPECT_EQ(product_of(std::vector<int>()), 1);
1690+
EXPECT_EQ(product_of(std::vector<int>(), 0), 0);
1691+
EXPECT_EQ(product_of(std::vector<int>(), 1), 1);
1692+
std::vector<int> V1 = {1, 2, 3, 4, 5};
1693+
static_assert(std::is_same_v<decltype(product_of(V1)), int>);
1694+
static_assert(std::is_same_v<decltype(product_of(V1, 1)), int>);
1695+
EXPECT_EQ(product_of(V1), 120);
1696+
EXPECT_EQ(product_of(V1, 1), 120);
1697+
EXPECT_EQ(product_of(V1, 2), 240);
1698+
1699+
std::vector<float> V2 = {1.0f, 2.0f, 4.0f};
1700+
static_assert(std::is_same_v<decltype(product_of(V2)), float>);
1701+
static_assert(std::is_same_v<decltype(product_of(V2), 1.0f), float>);
1702+
static_assert(std::is_same_v<decltype(product_of(V2), 1.0), double>);
1703+
EXPECT_EQ(product_of(V2), 8.0f);
1704+
EXPECT_EQ(product_of(V2, 4.0f), 32.0f);
1705+
1706+
// Make sure that for a const argument the return value is non-const.
1707+
const std::vector<float> V3 = {1.0f, 2.0f};
1708+
static_assert(std::is_same_v<decltype(product_of(V3)), float>);
1709+
EXPECT_EQ(product_of(V3), 2.0f);
16611710
}
16621711

16631712
struct Foo;

0 commit comments

Comments
 (0)