diff --git a/src/DerivativeFree/sidi.jl b/src/DerivativeFree/sidi.jl index d0b2af6..b9037b3 100644 --- a/src/DerivativeFree/sidi.jl +++ b/src/DerivativeFree/sidi.jl @@ -36,8 +36,8 @@ end function init_state(M::Sidi{k}, F::Callable_Function, x) where {k} x₀, x₁ = x₀x₁(x) - fx₀, xs, fs = _init_sidi(F, (x₀, x₁), k) - state = SidiState(xs[k], xs[k+1], fx₀, fs[1], xs, fs) + fx₀, xs, fs = _init_sidi(F, (x₀, x₁), k) + state = SidiState(xs[k], xs[k + 1], fx₀, fs[1], xs, fs) end function update_state( @@ -46,15 +46,14 @@ function update_state( o::SidiState{T,S}, options, l=NullTracks(), -) where {k, T, S} - +) where {k,T,S} xs, fs = o.xs, o.fs fxn1 = o.fxn1 _update_sidi!(F, xs, fs) incfn(l) @reset o.xn0 = xs[k] - @reset o.xn1 = xs[k+1] + @reset o.xn1 = xs[k + 1] @reset o.fxn0 = o.fxn1 @reset o.fxn1 = fs[1] @reset o.xs = xs @@ -79,7 +78,7 @@ function _init_sidi(f, x, k) x₀ = first(x) fx₀ = f(x₀) - xs = Vector{typeof( x₀)}(undef, k+1) + xs = Vector{typeof(x₀)}(undef, k+1) fs = Vector{typeof(fx₀)}(undef, k+1) n = length(x) @@ -93,64 +92,62 @@ function _init_sidi(f, x, k) fs[2] = (fx₀ - fs[1]) / (xs[1] - xs[2]) # build up diagonal by diagonal - for j ∈ 3:(k+1) + for j in 3:(k + 1) if j ≤ n # xⱼ was specified xⱼ = xs[j] else - xⱼ₋₁ = xs[j-1] - pk′ = evaluate_pk′(view(xs, 1:j-1), view(fs, 1:j-1)) + xⱼ₋₁ = xs[j - 1] + pk′ = evaluate_pk′(view(xs, 1:(j - 1)), view(fs, 1:(j - 1))) xⱼ = xs[j] = xⱼ₋₁ - fs[1] / pk′ end Δ = f(xⱼ) - for i ∈ 2:j - Δ₀ = fs[i-1] - fs[i-1] = Δ - Δ = (Δ₀ - Δ) / (xs[j-i+1] - xs[j]) + for i in 2:j + Δ₀ = fs[i - 1] + fs[i - 1] = Δ + Δ = (Δ₀ - Δ) / (xs[j - i + 1] - xs[j]) end fs[j] = Δ end # return fx₀ for bookkeeping purposes fx₀, xs, fs - end # update step: compute xn, fxn, update the xs,fs tableau function _update_sidi!(f, xs, fs) xₙ₋₁, fxₙ₋₁ = xs[end], fs[1] fn′ = evaluate_pk′(xs, fs) - xn = xₙ₋₁ - fxₙ₋₁ / fn′ + xn = xₙ₋₁ - fxₙ₋₁ / fn′ fxn = f(xn) update_tableau!(xn, fxn, xs, fs) end # formula (10) in paper to evaluate derivative of interpolating polynomial function evaluate_pk′(xs1, fs1) - δ = xs1[end] - xs1[end-1] + δ = xs1[end] - xs1[end - 1] Σ = fs1[2] k = length(xs1) - for i ∈ 3:k + for i in 3:k Σ = Σ + fs1[i] * δ - δ = δ * (xs1[end] - xs1[end-i+1]) + δ = δ * (xs1[end] - xs1[end - i + 1]) end Σ end - # update tableau's lower part # leaves [xn-k, xn-k+1, xn-k+2, ..., xn] # [fn, f(n-1,n), f(n-2, n-1, n), ..., f(n-k,n-k+1, ..., n)] function update_tableau!(xn, fxn, xs0, fs0) k = length(xs0) - for i in 1:k-1 - xs0[i] = xs0[i+1] + for i in 1:(k - 1) + xs0[i] = xs0[i + 1] end xs0[end] = xn Δ = fxn for i in 2:k - Δ₀ = fs0[i-1] - fs0[i-1] = Δ - Δ = (Δ₀ - Δ) / (xs0[end-i+1] - xn) + Δ₀ = fs0[i - 1] + fs0[i - 1] = Δ + Δ = (Δ₀ - Δ) / (xs0[end - i + 1] - xn) end fs0[end] = Δ xs0, fs0 diff --git a/src/find_zero.jl b/src/find_zero.jl index ecb2811..50e73ec 100644 --- a/src/find_zero.jl +++ b/src/find_zero.jl @@ -319,9 +319,9 @@ function init( end # helper for development use only -function __init(f,x,M,p=nothing; kwargs...) - s = init(ZeroProblem(f,x), M, p;kwargs...) - (M=s.M, F=s.F, state=s.state, options=s.options,logger=s.logger) +function __init(f, x, M, p=nothing; kwargs...) + s = init(ZeroProblem(f, x), M, p; kwargs...) + (M=s.M, F=s.F, state=s.state, options=s.options, logger=s.logger) end """ diff --git a/test/test_derivative_free.jl b/test/test_derivative_free.jl index 1cca4ac..4d8fa93 100644 --- a/test/test_derivative_free.jl +++ b/test/test_derivative_free.jl @@ -307,7 +307,7 @@ if !isinteractive() Roots.Order8(), Roots.Order16(), Roots.Sidi(2), - Roots.Sidi(3) + Roots.Sidi(3), ] results = [run_df_tests((f, b) -> find_zero(f, b, M), name="$M") for M in Ms] @@ -358,7 +358,7 @@ if !isinteractive() Roots.Order5(), Roots.Order8(), Roots.Order16(), - Roots.Sidi(2) + Roots.Sidi(2), ] Ts = [Float16, Float32, BigFloat] diff --git a/test/test_find_zero.jl b/test/test_find_zero.jl index 4b8dd4a..3199d33 100644 --- a/test/test_find_zero.jl +++ b/test/test_find_zero.jl @@ -28,7 +28,7 @@ struct Order3_Test <: Roots.AbstractSecantMethod end Roots.Thukral16(), Roots.LithBoonkkampIJzerman(3, 0), Roots.LithBoonkkampIJzerman(4, 0), - Roots.Sidi(2) + Roots.Sidi(2), ] ## different types of initial values @@ -597,8 +597,8 @@ end end @testset "similar methods" begin - Lsidi,Lsec = Roots.Tracks(),Roots.Tracks() + Lsidi, Lsec = Roots.Tracks(), Roots.Tracks() find_zero(sin, 3.0, Roots.Sidi(1); tracks=Lsidi) find_zero(sin, 3.0, Roots.Secant(); tracks=Lsec) - @test Lsidi.xfₛ[3:end] == Lsec.xfₛ[3:end] # drop x₀x₁ ordering + @test Lsidi.xfₛ[3:end] == Lsec.xfₛ[3:end] # drop x₀x₁ ordering end