Skip to content

Commit 254ce8b

Browse files
committed
return build_solution in scalar.jl
1 parent 17898cf commit 254ce8b

File tree

1 file changed

+22
-18
lines changed

1 file changed

+22
-18
lines changed

src/scalar.jl

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, a
1414
fx = f(x)
1515
dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), fx)
1616
end
17-
iszero(fx) && return NewtonSolution(x, DEFAULT, fx)
17+
iszero(fx) && return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT))
1818
Δx = dfx \ fx
1919
x -= Δx
2020
if isapprox(x, xo, atol=atol, rtol=rtol)
21-
return NewtonSolution(x, DEFAULT, fx)
21+
return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT))
2222
end
2323
xo = x
2424
end
25-
return NewtonSolution(x, MAXITERS_EXCEED, fx)
25+
return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(MAXITERS_EXCEED))
2626
end
2727

2828
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
@@ -33,7 +33,7 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
3333
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
3434
sol = solve(newprob, alg, args...; kwargs...)
3535

36-
uu = getsolution(sol)
36+
uu = sol.u
3737
if p isa Number
3838
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
3939
else
@@ -51,39 +51,43 @@ end
5151

5252
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
5353
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
54-
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid)
54+
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
55+
#return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid)
5556
end
5657
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
5758
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
58-
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid)
59+
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
60+
#return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode, sol.resid)
5961
end
6062

6163
# avoid ambiguities
6264
for Alg in [Bisection]
6365
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
6466
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
65-
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
67+
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials))
68+
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
6669
end
6770
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
6871
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
69-
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
72+
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials))
73+
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
7074
end
7175
end
7276

73-
function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
77+
function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxiters = 1000, kwargs...)
7478
f = Base.Fix2(prob.f, prob.p)
7579
left, right = prob.u0
7680
fl, fr = f(left), f(right)
7781

7882
if iszero(fl)
79-
return BracketingSolution(left, right, EXACT_SOLUTION_LEFT,fl)
83+
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right)
8084
end
8185

8286
i = 1
8387
if !iszero(fr)
8488
while i < maxiters
8589
mid = (left + right) / 2
86-
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT, fl)
90+
(mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
8791
fm = f(mid)
8892
if iszero(fm)
8993
right = mid
@@ -102,7 +106,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters
102106

103107
while i < maxiters
104108
mid = (left + right) / 2
105-
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT, fl)
109+
(mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
106110
fm = f(mid)
107111
if iszero(fm)
108112
right = mid
@@ -114,23 +118,23 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters
114118
i += 1
115119
end
116120

117-
return BracketingSolution(left, right, MAXITERS_EXCEED,fl)
121+
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right)
118122
end
119123

120-
function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...)
124+
function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = 1000, kwargs...)
121125
f = Base.Fix2(prob.f, prob.p)
122126
left, right = prob.u0
123127
fl, fr = f(left), f(right)
124128

125129
if iszero(fl)
126-
return BracketingSolution(left, right, EXACT_SOLUTION_LEFT,fl)
130+
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right)
127131
end
128132

129133
i = 1
130134
if !iszero(fr)
131135
while i < maxiters
132136
if nextfloat_tdir(left, prob.u0...) == right
133-
return BracketingSolution(left, right, FLOATING_POINT_LIMIT, fx)
137+
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
134138
end
135139
mid = (fr * left - fl * right) / (fr - fl)
136140
for i in 1:10
@@ -157,7 +161,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10
157161

158162
while i < maxiters
159163
mid = (left + right) / 2
160-
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT, fl)
164+
(mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
161165
fm = f(mid)
162166
if iszero(fm)
163167
right = mid
@@ -172,5 +176,5 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10
172176
i += 1
173177
end
174178

175-
return BracketingSolution(left, right, MAXITERS_EXCEED,fl)
179+
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right)
176180
end

0 commit comments

Comments
 (0)