Skip to content

Commit f092368

Browse files
committed
FW: Modify Does not Yet Pass
1 parent e063006 commit f092368

File tree

1 file changed

+28
-17
lines changed

1 file changed

+28
-17
lines changed

enzyme/test/Integration/ReverseMode/stl_list.cpp

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,19 @@
55

66
#include "../test_utils.h"
77

8+
#include <iostream>
89
#include <list>
910

1011

11-
template<typename ...T>
12-
extern double __enzyme_fwddiff(void*, T...);
13-
template<typename ...T>
14-
extern double __enzyme_autodiff(void*, T...);
12+
struct S {
13+
S(double r) : x(r) {};
14+
double x = 0.0;
15+
};
16+
17+
extern double __enzyme_fwddiff(void*, int, std::list<S>&, ...);
18+
extern double __enzyme_autodiff(void*, int, std::list<S>&, ...);
19+
extern double __enzyme_fwddiff(void*, int, std::list<double>&, std::list<double>&);
20+
extern double __enzyme_autodiff(void*, int, std::list<double>&, std::list<double>&);
1521

1622

1723
double test_iterate_list(std::list<double>& vals) {
@@ -23,13 +29,11 @@ double test_iterate_list(std::list<double>& vals) {
2329
return result;
2430
}
2531

26-
struct S {
27-
S(double r) : x(r) {};
28-
double x = 0.0;
29-
};
32+
double test_modify_list(std::list<S> & vals, double const * x) {
33+
// simplified function for comparison:
34+
//return (*x)*(*x);
3035

31-
double test_modify_list(std::list<S> vals, double x) {
32-
vals.front().x = x;
36+
vals.front().x = *x;
3337

3438
// iterate over list
3539
double result = 0.0;
@@ -46,6 +50,7 @@ void test_forward_list() {
4650
std::list<double> dvals = {1.0, 1.0, 1.0};
4751

4852
double ret = __enzyme_fwddiff((void*)test_iterate_list, enzyme_dup, vals, dvals);
53+
std::cout << "FW test_iterate_list ret=" << ret << "\n";
4954
APPROX_EQ(ret, 12., 1e-10);
5055
}
5156

@@ -55,8 +60,9 @@ void test_forward_list() {
5560
double x = 3.0;
5661
double dx = 1.0;
5762

58-
double ret = __enzyme_fwddiff((void*)test_modify_list, enzyme_const, vals, enzyme_dup, x, dx);
59-
APPROX_EQ(ret, 6., 1e-10);
63+
double ret = __enzyme_fwddiff((void*)test_modify_list, enzyme_const, vals, enzyme_dup, &x, &dx);
64+
std::cout << "FW test_modify_list ret=" << ret << " x=" << x << " dx=" << dx << "\n";
65+
APPROX_EQ(ret, 6., 1e-10); // FIXME: ret is 0 instead of 6
6066
}
6167
}
6268

@@ -67,24 +73,29 @@ void test_reverse_list() {
6773
std::list<double> dvals = {1.0, 1.0, 1.0};
6874

6975
double ret = __enzyme_autodiff((void*)test_iterate_list, enzyme_dup, vals, dvals);
70-
APPROX_EQ(ret, 12., 1e-10);
76+
std::cout << "ret=" << ret << "\n";
77+
APPROX_EQ(ret, 12., 1e-10); // FIXME: why is this NOT asserting on wrong return values?
78+
if (ret > 12.1 || ret < 11.9) { fprintf(stderr, "AD test_iterate_list: ret is wrong.\n"); abort(); }
7179
}
7280

7381
// list is const, then first value set to active
7482
{
7583
std::list<S> vals = {S{1.0}, S{2.0}, S{3.0}};
76-
double x = 3.0;
84+
double x = 3.5;
7785
double dx = 1.0;
7886

79-
double ret = __enzyme_autodiff((void*)test_modify_list, enzyme_const, vals, enzyme_dup, x, dx);
80-
APPROX_EQ(ret, 6., 1e-10);
87+
double ret = __enzyme_autodiff((void*)test_modify_list, enzyme_const, vals, enzyme_dup, &x, &dx);
88+
std::cout << "ret=" << ret << "\n";
89+
APPROX_EQ(ret, 6., 1e-10); // FIXME: why is this NOT asserting on wrong return values?
90+
if (ret > 6.1 || ret < 5.9) { fprintf(stderr, "AD test_modify_list: ret is wrong.\n"); abort(); }
8191
}
8292
}
8393

8494

8595
int main() {
8696
test_forward_list();
87-
test_reverse_list();
97+
// FIXME: all wrong so far
98+
//test_reverse_list();
8899
return 0;
89100
}
90101

0 commit comments

Comments
 (0)