Skip to content

Commit 186e8eb

Browse files
Marsella8Pietro Max Marsella
andauthored
added recurse_n (flexflow#1563)
* added recurse_n * fix --------- Co-authored-by: Pietro Max Marsella <marsella@stanford.edu>
1 parent 93298ed commit 186e8eb

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RECURSE_N_H
2+
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RECURSE_N_H
3+
4+
#include "utils/exception.h"
5+
6+
namespace FlexFlow {
7+
8+
/**
9+
* @brief
10+
* Applies function `f` to value `initial_value` n times recursively.
11+
*
12+
* @example
13+
* auto add_three = [](int x) { return x + 3; };
14+
* int result = recurse_n(add_three, 3, 5);
15+
* result -> f(f(f(5))) = ((5+3)+3)+3 = 14
16+
*
17+
* @throws RuntimeError if n is negative
18+
*/
19+
template <typename F, typename T>
20+
T recurse_n(F const &f, int n, T const &initial_value) {
21+
if (n < 0) {
22+
throw mk_runtime_error(
23+
fmt::format("Supplied n={} should be non-negative", n));
24+
}
25+
T t = initial_value;
26+
for (int i = 0; i < n; i++) {
27+
t = f(t);
28+
}
29+
return t;
30+
}
31+
32+
} // namespace FlexFlow
33+
34+
#endif
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#include "utils/containers/recurse_n.h"
2+
#include "utils/archetypes/value_type.h"
3+
#include <functional>
4+
5+
namespace FlexFlow {
6+
7+
using T = value_type<0>;
8+
using F = std::function<T(T)>; // F :: T -> T
9+
10+
template T recurse_n(F const &f, int n, T const &initial_value);
11+
12+
} // namespace FlexFlow
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include "utils/containers/recurse_n.h"
2+
#include <doctest/doctest.h>
3+
#include <string>
4+
5+
using namespace FlexFlow;
6+
7+
TEST_SUITE(FF_TEST_SUITE) {
8+
TEST_CASE("recurse_n") {
9+
auto append_bar = [](std::string const &x) {
10+
return x + std::string("Bar");
11+
};
12+
13+
SUBCASE("n = 0") {
14+
std::string result = recurse_n(append_bar, 0, std::string("Foo"));
15+
std::string correct = "Foo";
16+
CHECK(result == correct);
17+
}
18+
19+
SUBCASE("n = 3") {
20+
std::string result = recurse_n(append_bar, 3, std::string("Foo"));
21+
std::string correct = "FooBarBarBar";
22+
CHECK(result == correct);
23+
}
24+
25+
SUBCASE("n < 0") {
26+
CHECK_THROWS(recurse_n(append_bar, -1, std::string("Foo")));
27+
}
28+
}
29+
}

0 commit comments

Comments
 (0)