Skip to content

Commit 3505679

Browse files
Populate all colorvec fields and some fixes
1 parent af589d5 commit 3505679

File tree

1 file changed

+94
-41
lines changed

1 file changed

+94
-41
lines changed

ext/OptimizationSparseDiffExt.jl

Lines changed: 94 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
3333
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
3434
end
3535

36+
hess_sparsity = f.hess_prototype
37+
hess_colors = f.hess_colorvec
3638
if f.hess === nothing
3739
hess_sparsity = Symbolics.hessian_sparsity(_f, x)
3840
hess_colors = matrix_colors(tril(hess_sparsity))
@@ -62,6 +64,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
6264
cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
6365
end
6466

67+
cons_jac_prototype = f.cons_jac_prototype
68+
cons_jac_colorvec = f.cons_jac_colorvec
6569
if cons !== nothing && f.cons_j === nothing
6670
cons_jac_prototype = Symbolics.jacobian_sparsity(cons, zeros(eltype(x), num_cons),
6771
x)
@@ -79,6 +83,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
7983
cons_j = (J, θ) -> f.cons_j(J, θ, p)
8084
end
8185

86+
cons_hess_caches = [(; sparsity = f.cons_hess_prototype, colors = f.cons_hess_colorvec)]
8287
if cons !== nothing && f.cons_h === nothing
8388
function gen_conshess_cache(_f, x)
8489
conshess_sparsity = copy(Symbolics.hessian_sparsity(_f, x))
@@ -91,9 +96,10 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
9196
fcons = [(x) -> (_res = zeros(eltype(x), num_cons);
9297
cons(_res, x);
9398
_res[i]) for i in 1:num_cons]
99+
cons_hess_caches = gen_conshess_cache.(fcons, Ref(x))
94100
cons_h = function (res, θ)
95101
for i in 1:num_cons
96-
numauto_color_hessian!(res[i], fcons[i], θ, gen_conshess_cache(fcons[i], θ))
102+
numauto_color_hessian!(res[i], fcons[i], θ, cons_hess_caches[i])
97103
end
98104
end
99105
else
@@ -107,9 +113,12 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
107113
end
108114
return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
109115
cons = cons, cons_j = cons_j, cons_h = cons_h,
110-
hess_prototype = f.hess_prototype,
111-
cons_jac_prototype = f.cons_jac_prototype,
112-
cons_hess_prototype = f.cons_hess_prototype,
116+
hess_prototype = hess_sparsity,
117+
hess_colorvec = hess_colors,
118+
cons_jac_colorvec = cons_jac_colorvec,
119+
cons_jac_prototype = cons_jac_prototype,
120+
cons_hess_prototype = getfield.(cons_hess_caches, :sparsity),
121+
cons_hess_colorvec = getfield.(cons_hess_caches, :colors),
113122
lag_h, f.lag_hess_prototype)
114123
end
115124

@@ -132,6 +141,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
132141
grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
133142
end
134143

144+
hess_sparsity = f.hess_prototype
145+
hess_colors = f.hess_colorvec
135146
if f.hess === nothing
136147
hess_sparsity = Symbolics.hessian_sparsity(_f, cache.u0)
137148
hess_colors = matrix_colors(tril(hess_sparsity))
@@ -161,6 +172,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
161172
cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
162173
end
163174

175+
cons_jac_prototype = f.cons_jac_prototype
176+
cons_jac_colorvec = f.cons_jac_colorvec
164177
if cons !== nothing && f.cons_j === nothing
165178
cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
166179
zeros(eltype(cache.u0), num_cons),
@@ -177,6 +190,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
177190
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
178191
end
179192

193+
cons_hess_caches = [(; sparsity = f.cons_hess_prototype, colors = f.cons_hess_colorvec)]
180194
if cons !== nothing && f.cons_h === nothing
181195
function gen_conshess_cache(_f, x)
182196
conshess_sparsity = copy(Symbolics.hessian_sparsity(_f, x))
@@ -208,8 +222,11 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
208222
return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
209223
cons = cons, cons_j = cons_j, cons_h = cons_h,
210224
hess_prototype = hess_sparsity,
225+
hess_colorvec = hess_colors,
211226
cons_jac_prototype = cons_jac_prototype,
227+
cons_jac_colorvec = cons_jac_colorvec,
212228
cons_hess_prototype = getfield.(cons_hess_caches, :sparsity),
229+
cons_hess_colorvec = getfield.(cons_hess_caches, :colors),
213230
lag_h, f.lag_hess_prototype)
214231
end
215232

@@ -230,7 +247,9 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseFiniteDiff, p
230247
else
231248
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
232249
end
233-
250+
251+
hess_sparsity = f.hess_prototype
252+
hess_colors = f.hess_colorvec
234253
if f.hess === nothing
235254
hess_sparsity = Symbolics.hessian_sparsity(_f, x)
236255
hess_colors = matrix_colors(tril(hess_sparsity))
@@ -259,6 +278,8 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseFiniteDiff, p
259278
cons = (res, θ) -> f.cons(res, θ, p)
260279
end
261280

281+
cons_jac_prototype = f.cons_jac_prototype
282+
cons_jac_colorvec = f.cons_jac_colorvec
262283
if cons !== nothing && f.cons_j === nothing
263284
cons_jac_prototype = f.cons_jac_prototype === nothing ?
264285
Symbolics.jacobian_sparsity(cons,
@@ -279,6 +300,7 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseFiniteDiff, p
279300
cons_j = (J, θ) -> f.cons_j(J, θ, p)
280301
end
281302

303+
conshess_caches = [(; sparsity = f.cons_hess_prototype, colors = f.cons_hess_colorvec)]
282304
if cons !== nothing && f.cons_h === nothing
283305
function gen_conshess_cache(_f, x)
284306
conshess_sparsity = Symbolics.hessian_sparsity(_f, x)
@@ -290,44 +312,48 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseFiniteDiff, p
290312
fcons = [(x) -> (_res = zeros(eltype(x), num_cons);
291313
cons(_res, x);
292314
_res[i]) for i in 1:num_cons]
293-
315+
conshess_caches = gen_conshess_cache.(fcons, Ref(x))
294316
cons_h = function (res, θ)
295317
for i in 1:num_cons
296-
numauto_color_hessian!(res[i], fcons[i], θ, gen_conshess_cache(fcons[i], θ))
318+
numauto_color_hessian!(res[i], fcons[i], θ, conshess_caches[i])
297319
end
298320
end
299321
else
300322
cons_h = (res, θ) -> f.cons_h(res, θ, p)
301323
end
302324

303325
if f.lag_h === nothing
304-
lag_hess_cache = FD.HessianCache(copy(x))
305-
c = zeros(num_cons)
306-
h = zeros(length(x), length(x))
307-
lag_h = let c = c, h = h
308-
lag = function (θ, σ, μ)
309-
f.cons(c, θ, p)
310-
l = μ'c
311-
if !iszero(σ)
312-
l += σ * f.f(θ, p)
313-
end
314-
l
315-
end
316-
function (res, θ, σ, μ)
317-
FD.finite_difference_hessian!(res,
318-
(x) -> lag(x, σ, μ),
319-
θ,
320-
updatecache(lag_hess_cache, θ))
321-
end
322-
end
326+
# lag_hess_cache = FD.HessianCache(copy(x))
327+
# c = zeros(num_cons)
328+
# h = zeros(length(x), length(x))
329+
# lag_h = let c = c, h = h
330+
# lag = function (θ, σ, μ)
331+
# f.cons(c, θ, p)
332+
# l = μ'c
333+
# if !iszero(σ)
334+
# l += σ * f.f(θ, p)
335+
# end
336+
# l
337+
# end
338+
# function (res, θ, σ, μ)
339+
# FD.finite_difference_hessian!(res,
340+
# (x) -> lag(x, σ, μ),
341+
# θ,
342+
# updatecache(lag_hess_cache, θ))
343+
# end
344+
# end
345+
lag_h = nothing
323346
else
324347
lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p)
325348
end
326349
return OptimizationFunction{true}(f, adtype; grad = grad, hess = hess, hv = hv,
327350
cons = cons, cons_j = cons_j, cons_h = cons_h,
328-
hess_prototype = f.hess_prototype,
329-
cons_jac_prototype = f.cons_jac_prototype,
330-
cons_hess_prototype = f.cons_hess_prototype,
351+
hess_prototype = hess_sparsity,
352+
hess_colorvec = hess_colors,
353+
cons_jac_prototype = cons_jac_prototype,
354+
cons_jac_colorvec = cons_jac_colorvec,
355+
cons_hess_prototype = getfield.(conshess_caches, :sparsity),
356+
cons_hess_colorvec = getfield.(conshess_caches, :colors),
331357
lag_h, f.lag_hess_prototype)
332358
end
333359

@@ -337,7 +363,6 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
337363
error("$(string(adtype)) with SparseDiffTools does not support functions with more than 2 arguments")
338364
end
339365
_f = (θ, args...) -> first(f.f(θ, cache.p, args...))
340-
updatecache = (cache, x) -> (cache.xmm .= x; cache.xmp .= x; cache.xpm .= x; cache.xpp .= x; return cache)
341366

342367
if f.grad === nothing
343368
gradcache = FD.GradientCache(cache.u0, cache.u0)
@@ -346,7 +371,9 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
346371
else
347372
grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
348373
end
349-
374+
375+
hess_sparsity = f.hess_prototype
376+
hess_colors = f.hess_colorvec
350377
if f.hess === nothing
351378
hess_sparsity = Symbolics.hessian_sparsity(_f, cache.u0)
352379
hess_colors = matrix_colors(tril(hess_sparsity))
@@ -375,6 +402,8 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
375402
cons = (res, θ) -> f.cons(res, θ, cache.p)
376403
end
377404

405+
cons_jac_prototype = f.cons_jac_prototype
406+
cons_jac_colorvec = f.cons_jac_colorvec
378407
if cons !== nothing && f.cons_j === nothing
379408
cons_jac_prototype = f.cons_jac_prototype === nothing ?
380409
Symbolics.jacobian_sparsity(cons, zeros(eltype(cache.u0), num_cons),
@@ -394,6 +423,7 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
394423
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
395424
end
396425

426+
conshess_caches = [(; sparsity = f.cons_hess_prototype, colors = f.cons_hess_colorvec)]
397427
if cons !== nothing && f.cons_h === nothing
398428
function gen_conshess_cache(_f, x)
399429
conshess_sparsity = copy(Symbolics.hessian_sparsity(_f, x))
@@ -406,10 +436,10 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
406436
fcons = [(x) -> (_res = zeros(eltype(x), num_cons);
407437
cons(_res, x);
408438
_res[i]) for i in 1:num_cons]
409-
hesscaches = [gen_conshess_cache(fcons[i], cache.u0) for i in 1:num_cons]
439+
conshess_caches = [gen_conshess_cache(fcons[i], cache.u0) for i in 1:num_cons]
410440
cons_h = function (res, θ)
411441
for i in 1:num_cons
412-
numauto_color_hessian!(res[i], fcons[i], θ, hesscaches[i])
442+
numauto_color_hessian!(res[i], fcons[i], θ, conshess_caches[i])
413443
end
414444
end
415445
else
@@ -447,8 +477,11 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
447477
return OptimizationFunction{true}(f, adtype; grad = grad, hess = hess, hv = hv,
448478
cons = cons, cons_j = cons_j, cons_h = cons_h,
449479
hess_prototype = hess_sparsity,
480+
hess_colorvec = hess_colors,
450481
cons_jac_prototype = cons_jac_prototype,
451-
cons_hess_prototype = getfield.(hesscaches, :sparsity),
482+
cons_jac_colorvec = cons_jac_colorvec,
483+
cons_hess_prototype = getfield.(conshess_caches, :sparsity),
484+
cons_hess_colorvec = getfield.(conshess_caches, :colors),
452485
lag_h, f.lag_hess_prototype)
453486
end
454487

@@ -464,6 +497,8 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
464497
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
465498
end
466499

500+
hess_sparsity = f.hess_prototype
501+
hess_colors = f.hess_colorvec
467502
if f.hess === nothing
468503
hess_sparsity = Symbolics.hessian_sparsity(_f, x)
469504
hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))
@@ -494,6 +529,8 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
494529
cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
495530
end
496531

532+
cons_jac_prototype = f.cons_jac_prototype
533+
cons_jac_colorvec = f.cons_jac_colorvec
497534
if cons !== nothing && f.cons_j === nothing
498535
cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
499536
zeros(eltype(x), num_cons),
@@ -509,12 +546,16 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
509546
else
510547
cons_j = (J, θ) -> f.cons_j(J, θ, p)
511548
end
512-
549+
550+
conshess_sparsity = f.cons_hess_prototype
551+
conshess_colors = f.cons_hess_colorvec
513552
if cons !== nothing && f.cons_h === nothing
514-
553+
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
554+
conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(x))
555+
conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
515556
cons_h = function (res, θ)
516557
for i in 1:num_cons
517-
res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, ) do θ
558+
res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) do θ
518559
ReverseDiff.gradient(fncs[i], θ)
519560
end
520561
end
@@ -530,9 +571,12 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
530571
end
531572
return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
532573
cons = cons, cons_j = cons_j, cons_h = cons_h,
533-
hess_prototype = f.hess_prototype,
534-
cons_jac_prototype = f.cons_jac_prototype,
535-
cons_hess_prototype = f.cons_hess_prototype,
574+
hess_prototype = hess_sparsity,
575+
hess_colorvec = hess_colors,
576+
cons_jac_prototype = cons_jac_prototype,
577+
cons_jac_colorvec = cons_jac_colorvec,
578+
cons_hess_prototype = conshess_sparsity,
579+
cons_hess_colorvec = conshess_colors,
536580
lag_h, f.lag_hess_prototype)
537581
end
538582

@@ -546,6 +590,8 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
546590
grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
547591
end
548592

593+
hess_sparsity = f.hess_prototype
594+
hess_colors = f.hess_colorvec
549595
if f.hess === nothing
550596
hess_sparsity = Symbolics.hessian_sparsity(_f, cache.u0)
551597
hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))
@@ -574,6 +620,8 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
574620
cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
575621
end
576622

623+
cons_jac_prototype = f.cons_jac_prototype
624+
cons_jac_colorvec = f.cons_jac_colorvec
577625
if cons !== nothing && f.cons_j === nothing
578626
cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
579627
zeros(eltype(cache.u0), num_cons),
@@ -589,7 +637,9 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
589637
else
590638
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
591639
end
592-
640+
641+
conshess_sparsity = f.cons_hess_prototype
642+
conshess_colors = f.cons_hess_colorvec
593643
if cons !== nothing && f.cons_h === nothing
594644
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
595645
conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(cache.u0))
@@ -614,8 +664,11 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
614664
return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
615665
cons = cons, cons_j = cons_j, cons_h = cons_h,
616666
hess_prototype = hess_sparsity,
667+
hess_colorvec = hess_colors,
617668
cons_jac_prototype = cons_jac_prototype,
669+
cons_jac_colorvec = cons_jac_colorvec,
618670
cons_hess_prototype = conshess_sparsity,
671+
cons_hess_colorvec = conshess_colors,
619672
lag_h, f.lag_hess_prototype)
620673
end
621674

0 commit comments

Comments
 (0)