File tree Expand file tree Collapse file tree 3 files changed +75
-0
lines changed
test/src/utils/containers Expand file tree Collapse file tree 3 files changed +75
-0
lines changed Original file line number Diff line number Diff line change 1+ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RECURSE_N_H
2+ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RECURSE_N_H
3+
4+ #include " utils/exception.h"
5+
6+ namespace FlexFlow {
7+
8+ /* *
9+ * @brief
10+ * Applies function `f` to value `initial_value` n times recursively.
11+ *
12+ * @example
13+ * auto add_three = [](int x) { return x + 3; };
14+ * int result = recurse_n(add_three, 3, 5);
15+ * result -> f(f(f(5))) = ((5+3)+3)+3 = 14
16+ *
17+ * @throws RuntimeError if n is negative
18+ */
19+ template <typename F, typename T>
20+ T recurse_n (F const &f, int n, T const &initial_value) {
21+ if (n < 0 ) {
22+ throw mk_runtime_error (
23+ fmt::format (" Supplied n={} should be non-negative" , n));
24+ }
25+ T t = initial_value;
26+ for (int i = 0 ; i < n; i++) {
27+ t = f (t);
28+ }
29+ return t;
30+ }
31+
32+ } // namespace FlexFlow
33+
34+ #endif
Original file line number Diff line number Diff line change 1+ #include " utils/containers/recurse_n.h"
2+ #include " utils/archetypes/value_type.h"
3+ #include < functional>
4+
5+ namespace FlexFlow {
6+
7+ using T = value_type<0 >;
8+ using F = std::function<T(T)>; // F :: T -> T
9+
10+ template T recurse_n (F const &f, int n, T const &initial_value);
11+
12+ } // namespace FlexFlow
Original file line number Diff line number Diff line change 1+ #include " utils/containers/recurse_n.h"
2+ #include < doctest/doctest.h>
3+ #include < string>
4+
5+ using namespace FlexFlow ;
6+
7+ TEST_SUITE (FF_TEST_SUITE) {
8+ TEST_CASE (" recurse_n" ) {
9+ auto append_bar = [](std::string const &x) {
10+ return x + std::string (" Bar" );
11+ };
12+
13+ SUBCASE (" n = 0" ) {
14+ std::string result = recurse_n (append_bar, 0 , std::string (" Foo" ));
15+ std::string correct = " Foo" ;
16+ CHECK (result == correct);
17+ }
18+
19+ SUBCASE (" n = 3" ) {
20+ std::string result = recurse_n (append_bar, 3 , std::string (" Foo" ));
21+ std::string correct = " FooBarBarBar" ;
22+ CHECK (result == correct);
23+ }
24+
25+ SUBCASE (" n < 0" ) {
26+ CHECK_THROWS (recurse_n (append_bar, -1 , std::string (" Foo" )));
27+ }
28+ }
29+ }
You can’t perform that action at this time.
0 commit comments