Skip to content

Commit 8bfa3ba

Browse files
authored
Add recursion support (#14)
* allow lambda in a define to see outer scope so it can find itself * stop fixing up all closure statements at execution time * add tests proving recursion is working as expected
1 parent 691f9a7 commit 8bfa3ba

File tree

4 files changed

+82
-12
lines changed

4 files changed

+82
-12
lines changed

include/cons_expr/cons_expr.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -648,21 +648,18 @@ struct cons_expr
648648
}
649649

650650
// Closures contain all of their own scope
651-
LexicalScope new_scope;
651+
LexicalScope param_scope = scope;
652+
653+
// overwrite scope with the things we know we need params to be named
652654

653655
// set up params
654656
// technically I'm evaluating the params lazily while invoking the lambda, not before. Does it matter?
655657
for (const auto [name, parameter] : std::views::zip(engine.values[parameter_names], engine.values[params])) {
656-
new_scope.emplace_back(engine.get_if<identifier_type>(&name)->value, engine.eval(scope, parameter));
657-
}
658-
659-
Scratch fixed_statements{ engine.object_scratch };
660-
for (const auto &statement : engine.values[statements]) {
661-
fixed_statements.push_back(engine.fix_identifiers(statement, {}, new_scope));
658+
param_scope.emplace_back(engine.get_if<identifier_type>(&name)->value, engine.eval(scope, parameter));
662659
}
663660

664661
// TODO set up tail call elimination for last element of the sequence being evaluated?
665-
return engine.sequence(new_scope, engine.values.insert_or_find(fixed_statements));
662+
return engine.sequence(param_scope, statements);
666663
}
667664
};
668665

@@ -908,6 +905,7 @@ struct cons_expr
908905
auto locals = engine.get_lambda_parameter_names(engine.values[params[0]]);
909906

910907
// replace all references to captured values with constant copies
908+
// this is how we create the closure object
911909
Scratch fixed_statements{ engine.object_scratch };
912910

913911
for (const auto &statement : engine.values[params.sublist(1)]) {

src/ccons_expr/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] const char *argv[])
6767

6868
try {
6969
content_2 += to_string(evaluator,
70-
false,
70+
true,
7171
evaluator.sequence(
7272
evaluator.global_scope, std::get<lefticus::cons_expr<>::list_type>(evaluator.parse(content_1).first.value)));
7373
} catch (const std::exception &e) {

test/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ catch_discover_tests(
6565
.xml)
6666

6767
# Add a file containing a set of constexpr tests
68-
add_executable(constexpr_tests constexpr_tests.cpp list_tests.cpp parser_tests.cpp)
68+
add_executable(constexpr_tests constexpr_tests.cpp list_tests.cpp parser_tests.cpp recursion_tests.cpp)
6969
target_link_libraries(
7070
constexpr_tests
7171
PRIVATE cons_expr::cons_expr
@@ -93,7 +93,7 @@ catch_discover_tests(
9393

9494
# Disable the constexpr portion of the test, and build again this allows us to have an executable that we can debug when
9595
# things go wrong with the constexpr testing
96-
add_executable(relaxed_constexpr_tests constexpr_tests.cpp list_tests.cpp parser_tests.cpp)
96+
add_executable(relaxed_constexpr_tests constexpr_tests.cpp list_tests.cpp parser_tests.cpp recursion_tests.cpp)
9797
target_link_libraries(
9898
relaxed_constexpr_tests
9999
PRIVATE cons_expr::cons_expr
@@ -115,4 +115,4 @@ catch_discover_tests(
115115
OUTPUT_SUFFIX
116116
.xml)
117117

118-
target_include_directories(relaxed_constexpr_tests PRIVATE "${CMAKE_BINARY_DIR}/configured_files/include")
118+
target_include_directories(relaxed_constexpr_tests PRIVATE "${CMAKE_BINARY_DIR}/configured_files/include")

test/recursion_tests.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#include <catch2/catch_test_macros.hpp>
2+
#include <catch2/matchers/catch_matchers_floating_point.hpp>
3+
4+
#include <cons_expr/cons_expr.hpp>
5+
#include <cons_expr/utility.hpp>
6+
#include <internal_use_only/config.hpp>
7+
8+
using IntType = int;
9+
using FloatType = double;
10+
11+
template<typename Result> constexpr Result evaluate_to(std::string_view input)
12+
{
13+
lefticus::cons_expr<std::uint16_t, char, IntType, FloatType> evaluator;
14+
return evaluator.evaluate_to<Result>(input).value();
15+
}
16+
17+
template<typename Result> constexpr bool evaluate_expected(std::string_view input, auto result)
18+
{
19+
lefticus::cons_expr<std::uint16_t, char, IntType, FloatType> evaluator;
20+
return evaluator.evaluate_to<Result>(input).value() == result;
21+
}
22+
23+
TEST_CASE("Y-Combinator", "[recursion]")
24+
{
25+
STATIC_CHECK(evaluate_to<int>(
26+
R"(
27+
;; Y combinator definition
28+
(define Y
29+
(lambda (f)
30+
((lambda (x) (f (lambda (y) ((x x) y))))
31+
(lambda (x) (f (lambda (y) ((x x) y)))))))
32+
33+
;; Factorial using Y combinator
34+
(define factorial
35+
(Y (lambda (fact)
36+
(lambda (n)
37+
(if (== n 0)
38+
1
39+
(* n (fact (- n 1))))))))
40+
41+
(factorial 5)
42+
)") == 120);
43+
}
44+
45+
46+
TEST_CASE("expressive 'define' 1 level", "[recursion]")
47+
{
48+
STATIC_CHECK(evaluate_to<int>(
49+
R"(
50+
(define factorial
51+
(lambda (n)
52+
(if (== n 0)
53+
1
54+
(* n (factorial (- n 1))))))
55+
56+
(factorial 1)
57+
)") == 1);
58+
}
59+
60+
TEST_CASE("expressive 'define' 5 levels", "[recursion]")
61+
{
62+
STATIC_CHECK(evaluate_to<int>(
63+
R"(
64+
(define factorial
65+
(lambda (n)
66+
(if (== n 0)
67+
1
68+
(* n (factorial (- n 1))))))
69+
70+
(factorial 5)
71+
)") == 120);
72+
}

0 commit comments

Comments
 (0)