Commit 62a8197

Merge pull request #595 from SciML/fixchunksize
Fix chunksize issue
2 parents 1167183 + b82d7c3

2 files changed: +26 −12 lines

2 files changed

+26
-12
lines changed

ext/OptimizationReverseDiffExt.jl

Lines changed: 18 additions & 6 deletions
@@ -9,10 +9,20 @@ isdefined(Base, :get_extension) ? (using ReverseDiff, ReverseDiff.ForwardDiff) :
 
 struct OptimizationReverseDiffTag end
 
+function default_chunk_size(len)
+    if len < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
+        len
+    else
+        ForwardDiff.DEFAULT_CHUNK_THRESHOLD
+    end
+end
+
 function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
     p = SciMLBase.NullParameters(),
     num_cons = 0)
     _f = (θ, args...) -> first(f.f(θ, p, args...))
+
+    chunksize = default_chunk_size(length(x))
 
     if f.grad === nothing
         if adtype.compile
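
This helper is the core of the fix: previously the compiled-tape paths built Dual numbers with length(x) partials, so tape compilation and memory use scaled with the full input dimension. A minimal standalone sketch of the helper's behavior (assuming only ForwardDiff; the threshold is 12 in current ForwardDiff releases):

    using ForwardDiff

    # Same logic as the helper added above: keep the full length for small
    # inputs, otherwise cap at ForwardDiff's heuristic chunk threshold.
    default_chunk_size(len) =
        len < ForwardDiff.DEFAULT_CHUNK_THRESHOLD ? len : ForwardDiff.DEFAULT_CHUNK_THRESHOLD

    default_chunk_size(5)     # -> 5
    default_chunk_size(1000)  # -> 12 with the default threshold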
@@ -32,14 +42,14 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
     if f.hess === nothing
         if adtype.compile
             T = ForwardDiff.Tag(OptimizationReverseDiffTag(),eltype(x))
-            xdual = ForwardDiff.Dual{typeof(T),eltype(x),length(x)}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), length(x))...,))))
+            xdual = ForwardDiff.Dual{typeof(T),eltype(x),chunksize}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), chunksize)...,))))
             h_tape = ReverseDiff.GradientTape(_f, xdual)
             htape = ReverseDiff.compile(h_tape)
             function g(θ)
                 res1 = zeros(eltype(θ), length(θ))
                 ReverseDiff.gradient!(res1, htape, θ)
             end
-            jaccfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk(x), T)
+            jaccfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk{chunksize}(), T)
             hess = function (res, θ, args...)
                 ForwardDiff.jacobian!(res, g, θ, jaccfg, Val{false}())
             end
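
The hunk above makes the hand-built Duals for the compiled Hessian tape and the JacobianConfig agree on one capped chunk size; before, the Duals carried length(x) partials while the config used ForwardDiff.Chunk(x). A hedged sketch of the Chunk{N}() pattern with a toy objective (the function and sizes here are illustrative, not from the package):

    using ForwardDiff

    f(x) = sum(abs2, x)   # toy objective, for illustration only
    x = rand(100)
    chunksize = min(length(x), ForwardDiff.DEFAULT_CHUNK_THRESHOLD)

    # Chunk{N}() pins the number of partials per forward sweep at N, so every
    # object built from it shares one Dual width.
    cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{chunksize}())
    ForwardDiff.gradient(f, x, cfg)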
@@ -100,7 +110,7 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
                 ReverseDiff.gradient!(res1, htape, θ)
             end
             gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
-            jaccfgs = [ForwardDiff.JacobianConfig(gs[i], x, ForwardDiff.Chunk(x), T) for i in 1:num_cons]
+            jaccfgs = [ForwardDiff.JacobianConfig(gs[i], x, ForwardDiff.Chunk{chunksize}(), T) for i in 1:num_cons]
             cons_h = function (res, θ)
                 for i in 1:num_cons
                     ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}())
@@ -134,6 +144,8 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
     adtype::AutoReverseDiff, num_cons = 0)
     _f = (θ, args...) -> first(f.f(θ, cache.p, args...))
 
+    chunksize = default_chunk_size(length(cache.u0))
+
     if f.grad === nothing
         if adtype.compile
             _tape = ReverseDiff.GradientTape(_f, cache.u0)
@@ -152,14 +164,14 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
     if f.hess === nothing
         if adtype.compile
             T = ForwardDiff.Tag(OptimizationReverseDiffTag(),eltype(cache.u0))
-            xdual = ForwardDiff.Dual{typeof(T),eltype(cache.u0),length(cache.u0)}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), length(cache.u0))...,))))
+            xdual = ForwardDiff.Dual{typeof(T),eltype(cache.u0),chunksize}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), chunksize)...,))))
             h_tape = ReverseDiff.GradientTape(_f, xdual)
             htape = ReverseDiff.compile(h_tape)
             function g(θ)
                 res1 = zeros(eltype(θ), length(θ))
                 ReverseDiff.gradient!(res1, htape, θ)
             end
-            jaccfg = ForwardDiff.JacobianConfig(g, cache.u0, ForwardDiff.Chunk(cache.u0), T)
+            jaccfg = ForwardDiff.JacobianConfig(g, cache.u0, ForwardDiff.Chunk{chunksize}(), T)
             hess = function (res, θ, args...)
                 ForwardDiff.jacobian!(res, g, θ, jaccfg, Val{false}())
             end
@@ -220,7 +232,7 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
                 ReverseDiff.gradient!(res1, htape, θ)
             end
             gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
-            jaccfgs = [ForwardDiff.JacobianConfig(gs[i], cache.u0, ForwardDiff.Chunk(cache.u0), T) for i in 1:num_cons]
+            jaccfgs = [ForwardDiff.JacobianConfig(gs[i], cache.u0, ForwardDiff.Chunk{chunksize}(), T) for i in 1:num_cons]
             cons_h = function (res, θ)
                 for i in 1:num_cons
                     ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}())

ext/OptimizationSparseDiffExt.jl

Lines changed: 8 additions & 6 deletions
@@ -492,6 +492,8 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
     num_cons = 0)
     _f = (θ, args...) -> first(f.f(θ, p, args...))
 
+    chunksize = default_chunk_size(length(x))
+
     if f.grad === nothing
         if adtype.compile
             _tape = ReverseDiff.GradientTape(_f, x)
@@ -514,7 +516,7 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
     hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))
     if adtype.compile
         T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(x))
-        xdual = ForwardDiff.Dual{typeof(T),eltype(x),length(x)}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), length(x))...,))))
+        xdual = ForwardDiff.Dual{typeof(T),eltype(x),min(chunksize, maximum(hess_colors))}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), min(chunksize, maximum(hess_colors)))...,))))
         h_tape = ReverseDiff.GradientTape(_f, xdual)
         htape = ReverseDiff.compile(h_tape)
         function g(res1, θ)
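
In the sparse path the Dual width is capped further by the number of structural colors, since a colored Hessian sweep never needs more partials than maximum(hess_colors). A small illustration (the tridiagonal pattern below is a stand-in for the detected sparsity, not from the commit):

    using SparseArrays, SparseDiffTools

    # A tridiagonal sparsity pattern needs only 3 colors, so 3 partials per
    # Dual suffice regardless of length(x); min(chunksize, maximum(colors))
    # exploits exactly that.
    S = spdiagm(-1 => ones(9), 0 => ones(10), 1 => ones(9))
    colors = SparseDiffTools.matrix_colors(S)
    maximum(colors)  # -> 3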
@@ -582,15 +584,14 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
     conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
     if adtype.compile
         T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(x))
-        xduals = [ForwardDiff.Dual{typeof(T),eltype(x),maximum(conshess_colors[i])}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), maximum(conshess_colors[i]))...,)))) for i in 1:num_cons]
+        xduals = [ForwardDiff.Dual{typeof(T),eltype(x),min(chunksize, maximum(conshess_colors[i]))}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), min(chunksize, maximum(conshess_colors[i])))...,)))) for i in 1:num_cons]
         consh_tapes = [ReverseDiff.GradientTape(fncs[i], xduals[i]) for i in 1:num_cons]
         conshtapes = ReverseDiff.compile.(consh_tapes)
         function grad_cons(res1, θ, htape)
             ReverseDiff.gradient!(res1, htape, θ)
         end
         gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons]
         jaccfgs = [ForwardColorJacCache(gs[i], x; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons]
-        println(jaccfgs)
         cons_h = function (res, θ)
             for i in 1:num_cons
                 SparseDiffTools.forwarddiff_color_jacobian!(res[i], gs[i], θ, jaccfgs[i])
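
This hunk also drops a stray println(jaccfgs) debug statement that dumped the jacobian caches to stdout; the ReInitCache method below gets the same cleanup.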
@@ -629,6 +630,8 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
     adtype::AutoSparseReverseDiff, num_cons = 0)
     _f = (θ, args...) -> first(f.f(θ, cache.p, args...))
 
+    chunksize = default_chunk_size(length(cache.u0))
+
     if f.grad === nothing
         if adtype.compile
             _tape = ReverseDiff.GradientTape(_f, cache.u0)
@@ -651,7 +654,7 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
     hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))
     if adtype.compile
         T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(cache.u0))
-        xdual = ForwardDiff.Dual{typeof(T),eltype(cache.u0),length(cache.u0)}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), length(cache.u0))...,))))
+        xdual = ForwardDiff.Dual{typeof(T),eltype(cache.u0),min(chunksize, maximum(hess_colors))}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), min(chunksize, maximum(hess_colors)))...,))))
         h_tape = ReverseDiff.GradientTape(_f, xdual)
         htape = ReverseDiff.compile(h_tape)
         function g(res1, θ)
@@ -719,15 +722,14 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
     conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
     if adtype.compile
         T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(cache.u0))
-        xduals = [ForwardDiff.Dual{typeof(T),eltype(cache.u0),maximum(conshess_colors[i])}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), maximum(conshess_colors[i]))...,)))) for i in 1:num_cons]
+        xduals = [ForwardDiff.Dual{typeof(T),eltype(cache.u0),min(chunksize, maximum(conshess_colors[i]))}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), min(chunksize, maximum(conshess_colors[i])))...,)))) for i in 1:num_cons]
         consh_tapes = [ReverseDiff.GradientTape(fncs[i], xduals[i]) for i in 1:num_cons]
         conshtapes = ReverseDiff.compile.(consh_tapes)
         function grad_cons(res1, θ, htape)
             ReverseDiff.gradient!(res1, htape, θ)
         end
         gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons]
         jaccfgs = [ForwardColorJacCache(gs[i], cache.u0; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons]
-        println(jaccfgs)
         cons_h = function (res, θ)
             for i in 1:num_cons
                 SparseDiffTools.forwarddiff_color_jacobian!(res[i], gs[i], θ, jaccfgs[i])
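
From the user's side the API is unchanged; the compiled-tape branches simply become usable for larger parameter vectors. A hypothetical end-to-end sketch (problem and solver choice are illustrative only, not part of this commit):

    using Optimization, OptimizationOptimJL

    rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2

    # compile = true routes through the compiled-tape branches patched above.
    optf = OptimizationFunction(rosenbrock, Optimization.AutoReverseDiff(compile = true))
    prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])
    sol = solve(prob, Newton())  # Newton exercises the ForwardDiff-over-ReverseDiff Hessian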
