Skip to content

Commit 1c3349e

Browse files
Merge pull request #597 from SciML/performancetuning
Eliminate some runtime dispatch and other things
2 parents ff678ed + 57c2f39 commit 1c3349e

File tree

1 file changed

+51
-24
lines changed

1 file changed

+51
-24
lines changed

ext/OptimizationSparseDiffExt.jl

Lines changed: 51 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -561,16 +561,18 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
561561
cons_jac_prototype = f.cons_jac_prototype
562562
cons_jac_colorvec = f.cons_jac_colorvec
563563
if cons !== nothing && f.cons_j === nothing
564-
cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
565-
zeros(eltype(x), num_cons),
566-
x)
567-
cons_jac_colorvec = matrix_colors(cons_jac_prototype)
568-
jaccache = ForwardColorJacCache(cons, x;
569-
colorvec = cons_jac_colorvec,
570-
sparsity = cons_jac_prototype,
571-
dx = zeros(eltype(x), num_cons))
572-
cons_j = function (J, θ)
573-
forwarddiff_color_jacobian!(J, cons, θ, jaccache)
564+
jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), SparseDiffTools.SymbolicsSparsityDetection(), cons_oop, x, fx = zeros(eltype(x), num_cons))
565+
# let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons
566+
# ForwardColorJacCache(cons, θ;
567+
# colorvec = cons_jac_colorvec,
568+
# sparsity = cons_jac_prototype,
569+
# dx = zeros(eltype(θ), num_cons))
570+
# end
571+
cons_jac_prototype = jaccache.jac_prototype
572+
cons_jac_colorvec = jaccache.coloring
573+
cons_j = function (J, θ, args...;cons = cons, cache = jaccache.cache)
574+
forwarddiff_color_jacobian!(J, cons, θ, cache)
575+
return
574576
end
575577
else
576578
cons_j = (J, θ) -> f.cons_j(J, θ, p)
@@ -592,7 +594,7 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
592594
end
593595
gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons]
594596
jaccfgs = [ForwardColorJacCache(gs[i], x; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons]
595-
cons_h = function (res, θ)
597+
cons_h = function (res, θ, args...)
596598
for i in 1:num_cons
597599
SparseDiffTools.forwarddiff_color_jacobian!(res[i], gs[i], θ, jaccfgs[i])
598600
end
@@ -692,23 +694,32 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
692694
if f.cons === nothing
693695
cons = nothing
694696
else
695-
cons = (res, θ) -> f.cons(res, θ, cache.p)
697+
cons = function (res, θ)
698+
f.cons(res, θ, cache.p)
699+
return
700+
end
696701
cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
697702
end
698703

699704
cons_jac_prototype = f.cons_jac_prototype
700705
cons_jac_colorvec = f.cons_jac_colorvec
701706
if cons !== nothing && f.cons_j === nothing
702-
cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
703-
zeros(eltype(cache.u0), num_cons),
704-
cache.u0)
705-
cons_jac_colorvec = matrix_colors(cons_jac_prototype)
706-
jaccache = ForwardColorJacCache(cons, cache.u0;
707-
colorvec = cons_jac_colorvec,
708-
sparsity = cons_jac_prototype,
709-
dx = zeros(eltype(cache.u0), num_cons))
707+
# cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
708+
# zeros(eltype(cache.u0), num_cons),
709+
# cache.u0)
710+
# cons_jac_colorvec = matrix_colors(cons_jac_prototype)
711+
jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), SparseDiffTools.SymbolicsSparsityDetection(), cons_oop, cache.u0, fx = zeros(eltype(cache.u0), num_cons))
712+
# let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons
713+
# ForwardColorJacCache(cons, θ;
714+
# colorvec = cons_jac_colorvec,
715+
# sparsity = cons_jac_prototype,
716+
# dx = zeros(eltype(θ), num_cons))
717+
# end
718+
cons_jac_prototype = jaccache.jac_prototype
719+
cons_jac_colorvec = jaccache.coloring
710720
cons_j = function (J, θ)
711-
forwarddiff_color_jacobian!(J, cons, θ, jaccache)
721+
forwarddiff_color_jacobian!(J, cons, θ, jaccache.cache)
722+
return
712723
end
713724
else
714725
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
@@ -717,8 +728,18 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
717728
conshess_sparsity = f.cons_hess_prototype
718729
conshess_colors = f.cons_hess_colorvec
719730
if cons !== nothing && f.cons_h === nothing
720-
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
721-
conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(cache.u0))
731+
fncs = map(1:num_cons) do i
732+
function (x)
733+
res = zeros(eltype(x), num_cons)
734+
f.cons(res, x, cache.p)
735+
return res[i]
736+
end
737+
end
738+
conshess_sparsity = map(1:num_cons) do i
739+
let fnc = fncs[i], θ = cache.u0
740+
Symbolics.hessian_sparsity(fnc, θ)
741+
end
742+
end
722743
conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
723744
if adtype.compile
724745
T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(cache.u0))
@@ -728,7 +749,13 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
728749
function grad_cons(res1, θ, htape)
729750
ReverseDiff.gradient!(res1, htape, θ)
730751
end
731-
gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons]
752+
gs = let conshtapes = conshtapes
753+
map(1:num_cons) do i
754+
function (res1, x)
755+
grad_cons(res1, x, conshtapes[i])
756+
end
757+
end
758+
end
732759
jaccfgs = [ForwardColorJacCache(gs[i], cache.u0; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons]
733760
cons_h = function (res, θ)
734761
for i in 1:num_cons

0 commit comments

Comments
 (0)