@@ -26,7 +26,7 @@ p = [σ => 28.0,
26
26
β => 8 / 3 ]
27
27
28
28
tspan = (0.0 , 100.0 )
29
- prob = ODEProblem (sys, u0, tspan, p, jac = true )
29
+ prob = ODEProblem (sys, u0, tspan, p)
30
30
sol = solve (prob, Tsit5 ())
31
31
32
32
@testset " AutoDiff Observable Functions" begin
@@ -35,25 +35,27 @@ sol = solve(prob, Tsit5())
35
35
end
36
36
du_ = [0.0 , 1.0 , 1.0 , 1.0 ]
37
37
du = [du_ for _ in sol. u]
38
- @test du == gs
38
+ @test du == gs. u
39
39
40
40
# Observable in a vector
41
41
gs, = gradient (sol) do sol
42
42
sum (sum .(sol[[sys. w, sys. x]]))
43
43
end
44
44
du_ = [0.0 , 1.0 , 1.0 , 2.0 ]
45
45
du = [du_ for _ in sol. u]
46
- @test du == gs
46
+ @test du == gs. u
47
47
end
48
48
49
- # @testset "AD Observable Functions for Initialization" begin
50
- # iprob = prob.f.initialization_data.initalizeprob
51
- # isol = solve(iprob)
52
- # gs, = gradient(isol) do isol
53
- # isol[w]
54
- # end
49
+ @testset " AD Observable Functions for Initialization" begin
50
+ iprob = prob. f. initialization_data. initializeprob
51
+ isol = solve (iprob)
52
+ gs, = Zygote . gradient (isol) do isol
53
+ isol[w]
54
+ end
55
55
56
- # end
56
+ @test gs isa NamedTuple
57
+ @test isempty (setdiff (fieldnames (typeof (gs)), fieldnames (typeof (isol))))
58
+ end
57
59
58
60
# DAE
59
61
92
94
end
93
95
du_ = [0.2 , 1.0 ]
94
96
du = [du_ for _ in sol. u]
95
- @test gs == du
97
+ @test gs. u == du
96
98
97
99
@testset " DAE Initialization Observable function AD" begin
98
100
iprob = prob. f. initialization_data. initializeprob
0 commit comments