Skip to content

Commit e457a98

Browse files
Merge pull request #457 from ErikQQY/qqy/better_bigfloat
Better BigFloat support
2 parents 8dec0c9 + 86e6720 commit e457a98

File tree

6 files changed

+17
-25
lines changed

6 files changed

+17
-25
lines changed

src/default.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
266266
alias_u0 = false # If immutable don't care about aliasing
267267
end
268268
u0 = prob.u0
269-
u0_aliased = alias_u0 ? __similar(u0) : u0
269+
u0_aliased = alias_u0 ? zero(u0) : u0
270270
end]
271271
for i in 1:N
272272
cur_sol = sol_syms[i]

src/globalization/line_search.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ function __internal_init(
125125
deriv_op = nothing
126126
elseif SciMLBase.has_jvp(f)
127127
if isinplace(prob)
128-
jvp_cache = __similar(fu)
128+
jvp_cache = zero(fu)
129129
deriv_op = @closure (du, u, fu, p) -> begin
130130
f.jvp(jvp_cache, du, u, p)
131131
dot(fu, jvp_cache)
@@ -135,7 +135,7 @@ function __internal_init(
135135
end
136136
elseif SciMLBase.has_vjp(f)
137137
if isinplace(prob)
138-
vjp_cache = __similar(u)
138+
vjp_cache = zero(u)
139139
deriv_op = @closure (du, u, fu, p) -> begin
140140
f.vjp(vjp_cache, fu, u, p)
141141
dot(du, vjp_cache)
@@ -149,7 +149,7 @@ function __internal_init(
149149
alg.autodiff, prob; check_reverse_mode = true)
150150
vjp_op = VecJacOperator(prob, fu, u; autodiff)
151151
if isinplace(prob)
152-
vjp_cache = __similar(u)
152+
vjp_cache = zero(u)
153153
deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(vjp_cache, fu, u, p))
154154
else
155155
deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(fu, u, p))
@@ -159,7 +159,7 @@ function __internal_init(
159159
alg.autodiff, prob; check_forward_mode = true)
160160
jvp_op = JacVecOperator(prob, fu, u; autodiff)
161161
if isinplace(prob)
162-
jvp_cache = __similar(fu)
162+
jvp_cache = zero(fu)
163163
deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(jvp_cache, du, u, p))
164164
else
165165
deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(du, u, p))

src/internal/helpers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
function evaluate_f(prob::AbstractNonlinearProblem{uType, iip}, u) where {uType, iip}
33
(; f, u0, p) = prob
44
if iip
5-
fu = f.resid_prototype === nothing ? __similar(u) :
5+
fu = f.resid_prototype === nothing ? zero(u) :
66
promote_type(eltype(u), eltype(f.resid_prototype)).(f.resid_prototype)
77
f(fu, u, p)
88
else
@@ -156,7 +156,7 @@ function __construct_extension_f(prob::AbstractNonlinearProblem; alias_u0::Bool
156156

157157
𝐅 = if force_oop === True && applicable(𝐟, u0, u0)
158158
_resid = resid isa Number ? [resid] : _vec(resid)
159-
du = _vec(__similar(_resid))
159+
du = _vec(zero(_resid))
160160
@closure u -> begin
161161
𝐟(du, u)
162162
return du

src/internal/jacobian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
8585
__similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u)) :
8686
copy(f.jac_prototype)
8787
elseif f.jac_prototype === nothing
88-
__init_bigfloat_array!!(init_jacobian(
88+
zero(init_jacobian(
8989
jac_cache; preserve_immutable = Val(true)))
9090
else
9191
f.jac_prototype

src/internal/operators.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
7474
@closure (v, u, p) -> auto_vecjac(uf, u, v)
7575
elseif vjp_autodiff isa AutoFiniteDiff
7676
if iip
77-
cache1 = __similar(fu)
78-
cache2 = __similar(fu)
77+
cache1 = zero(fu)
78+
cache2 = zero(fu)
7979
@closure (Jv, v, u, p) -> num_vecjac!(Jv, uf, u, v, cache1, cache2)
8080
else
8181
@closure (v, u, p) -> num_vecjac(uf, __mutable(u), v)
@@ -106,17 +106,17 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
106106
if iip
107107
# FIXME: Technically we should propagate the tag but ignoring that for now
108108
cache1 = Dual{typeof(ForwardDiff.Tag(uf, eltype(u))), eltype(u),
109-
1}.(__similar(u), ForwardDiff.Partials.(tuple.(u)))
109+
1}.(zero(u), ForwardDiff.Partials.(tuple.(u)))
110110
cache2 = Dual{typeof(ForwardDiff.Tag(uf, eltype(fu))), eltype(fu),
111-
1}.(__similar(fu), ForwardDiff.Partials.(tuple.(fu)))
111+
1}.(zero(fu), ForwardDiff.Partials.(tuple.(fu)))
112112
@closure (Jv, v, u, p) -> auto_jacvec!(Jv, uf, u, v, cache1, cache2)
113113
else
114114
@closure (v, u, p) -> auto_jacvec(uf, u, v)
115115
end
116116
elseif jvp_autodiff isa AutoFiniteDiff
117117
if iip
118-
cache1 = __similar(fu)
119-
cache2 = __similar(u)
118+
cache1 = zero(fu)
119+
cache2 = zero(u)
120120
@closure (Jv, v, u, p) -> num_jacvec!(Jv, uf, u, v, cache1, cache2)
121121
else
122122
@closure (v, u, p) -> num_jacvec(uf, u, v)
@@ -162,15 +162,15 @@ end
162162
function (op::JacobianOperator{vjp, iip})(v, u, p) where {vjp, iip}
163163
if vjp
164164
if iip
165-
res = __similar(op.output_cache)
165+
res = zero(op.output_cache)
166166
op.vjp_op(res, v, u, p)
167167
return res
168168
else
169169
return op.vjp_op(v, u, p)
170170
end
171171
else
172172
if iip
173-
res = __similar(op.output_cache)
173+
res = zero(op.output_cache)
174174
op.jvp_op(res, v, u, p)
175175
return res
176176
else

src/utils.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,5 @@ end
163163

164164
function __similar(x, args...; kwargs...)
165165
y = similar(x, args...; kwargs...)
166-
return __init_bigfloat_array!!(y)
167-
end
168-
169-
function __init_bigfloat_array!!(x)
170-
if ArrayInterface.can_setindex(x)
171-
eltype(x) <: BigFloat && fill!(x, BigFloat(0))
172-
return x
173-
end
174-
return x
166+
return zero(y)
175167
end

0 commit comments

Comments
 (0)