49
49
function ITP (; scaled_k1:: Real = 0.2 , k2:: Real = 2 , n0:: Int = 10 )
50
50
scaled_k1 < 0 && error (" Hyper-parameter κ₁ should not be negative" )
51
51
n0 < 0 && error (" Hyper-parameter n₀ should not be negative" )
52
- if k2 < 1 || k2 > ( 1.5 + sqrt (5 ) / 2 )
52
+ if ! ( 1 <= k2 <= 1.5 + sqrt (5 ) / 2 )
53
53
throw (ArgumentError (" Hyper-parameter κ₂ should be between 1 and 1 + ϕ where \
54
54
ϕ ≈ 1.618... is the golden ratio" ))
55
55
end
@@ -63,22 +63,23 @@ function CommonSolve.solve(
63
63
@assert ! SciMLBase. isinplace (prob) " `ITP` only supports out-of-place problems."
64
64
65
65
f = Base. Fix2 (prob. f, prob. p)
66
- left, right = prob. tspan
66
+ left, right = minmax ( promote ( prob. tspan... ) ... )
67
67
fl, fr = f (left), f (right)
68
68
69
69
abstol = NonlinearSolveBase. get_tolerance (
70
70
left, abstol, promote_type (eltype (left), eltype (right))
71
71
)
72
72
73
+ stats = SciMLBase. NLStats (2 ,0 ,0 ,0 ,0 )
73
74
if iszero (fl)
74
75
return SciMLBase. build_solution (
75
- prob, alg, left, fl; retcode = ReturnCode. ExactSolutionLeft, left, right
76
+ prob, alg, left, fl; retcode = ReturnCode. ExactSolutionLeft, left, right, stats
76
77
)
77
78
end
78
79
79
80
if iszero (fr)
80
81
return SciMLBase. build_solution (
81
- prob, alg, right, fr; retcode = ReturnCode. ExactSolutionRight, left, right
82
+ prob, alg, right, fr; retcode = ReturnCode. ExactSolutionRight, left, right, stats
82
83
)
83
84
end
84
85
@@ -87,73 +88,63 @@ function CommonSolve.solve(
87
88
@warn " The interval is not an enclosing interval, opposite signs at the \
88
89
boundaries are required."
89
90
return SciMLBase. build_solution (
90
- prob, alg, left, fl; retcode = ReturnCode. InitialFailure, left, right
91
+ prob, alg, left, fl; retcode = ReturnCode. InitialFailure, left, right, stats
91
92
)
92
93
end
93
94
94
95
ϵ = abstol
95
96
k2 = alg. k2
96
- k1 = alg. scaled_k1 * abs (right - left)^ (1 - k2)
97
+ span = right - left
98
+ k1 = alg. scaled_k1 * span^ (1 - k2) # k1 > 0
97
99
n0 = alg. n0
98
- mid = (left + right) / 2
99
- x_f = left + (right - left) * (fl / (fl - fr))
100
- xt = left
101
- xp = left
102
- r = zero (left) # minmax radius
103
- δ = zero (left) # truncation error
104
- σ = one (mid)
105
- n_h = exponent (abs (right - left) / (2 * ϵ))
100
+ n_h = exponent (span / (2 * ϵ))
106
101
ϵ_s = ϵ * exp2 (n_h + n0)
102
+ T0 = zero (fl)
107
103
108
104
i = 1
109
105
while i ≤ maxiters
110
- span = abs (right - left)
106
+ stats. nsteps += 1
107
+ span = right - left
108
+ mid = (left + right) / 2
111
109
r = ϵ_s - (span / 2 )
112
- δ = k1 * span^ k2
113
110
114
- x_f = left + (right - left) * (fl / (fl - fr)) # Interpolation Step
111
+ x_f = left + span * (fl / (fl - fr)) # Interpolation Step
115
112
113
+ δ = max (k1 * span^ k2, eps (x_f))
116
114
diff = mid - x_f
117
- σ = sign (diff)
118
- xt = ifelse (δ ≤ diff, x_f + σ * δ, mid) # Truncation Step
119
115
120
- xp = ifelse (abs (xt - mid) ≤ r, xt , mid - σ * r ) # Projection Step
116
+ xt = ifelse (δ ≤ abs (diff), x_f + copysign (δ, diff) , mid) # Truncation Step
121
117
122
- if abs ((left - right) / 2 ) < ϵ
118
+ xp = ifelse (abs (xt - mid) ≤ r, xt, mid - copysign (r, diff)) # Projection Step
119
+ if span < 2 ϵ
123
120
return SciMLBase. build_solution (
124
- prob, alg, xt, f (xt); retcode = ReturnCode. Success, left, right
121
+ prob, alg, xt, f (xt); retcode = ReturnCode. Success, left, right, stats
125
122
)
126
123
end
127
-
128
- # update
129
- tmin, tmax = minmax (xt, xp)
130
- xp ≥ tmax && (xp = prevfloat (tmax))
131
- xp ≤ tmin && (xp = nextfloat (tmin))
132
124
yp = f (xp)
125
+ stats. nf += 1
133
126
yps = yp * sign (fr)
134
- T0 = zero (yps)
135
127
if yps > T0
136
128
right, fr = xp, yp
137
129
elseif yps < T0
138
130
left, fl = xp, yp
139
131
else
140
132
return SciMLBase. build_solution (
141
- prob, alg, xp, yps; retcode = ReturnCode. Success, left, right
133
+ prob, alg, xp, yps; retcode = ReturnCode. Success, left, right, stats
142
134
)
143
135
end
144
136
145
137
i += 1
146
- mid = (left + right) / 2
147
138
ϵ_s /= 2
148
139
149
- if Impl . nextfloat_tdir (left, prob . tspan ... ) == right
140
+ if nextfloat (left) == right
150
141
return SciMLBase. build_solution (
151
- prob, alg, right, fr; retcode = ReturnCode. FloatingPointLimit, left, right
142
+ prob, alg, right, fr; retcode = ReturnCode. FloatingPointLimit, left, right, stats
152
143
)
153
144
end
154
145
end
155
146
156
147
return SciMLBase. build_solution (
157
- prob, alg, left, fl; retcode = ReturnCode. MaxIters, left, right
148
+ prob, alg, left, fl; retcode = ReturnCode. MaxIters, left, right, stats
158
149
)
159
150
end
0 commit comments