@@ -101,6 +101,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
101101 set_runtime_activity2 (Enzyme. Forward, adtype. mode)
102102 end
103103
104+ func_annot = if adtype. mode isa Nothing
105+ Nothing
106+ else
107+ adtype. mode. function_annotation
108+ end
109+
104110 if g == true && f. grad === nothing
105111 function grad (res, θ, p = p)
106112 Enzyme. make_zero! (res)
@@ -217,6 +223,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
217223 # if num_cons > length(x)
218224 seeds = Enzyme. onehot (x)
219225 Jaccache = Tuple (zeros (eltype (x), num_cons) for i in 1 : length (x))
226+ basefunc = f. cons
227+ if func_annot <: Enzyme.Const
228+ basefunc = Enzyme. Const (basefunc)
229+ elseif func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated
230+ basefunc = Enzyme. BatchDuplicated (basefunc, Tuple (make_zero (basefunc) for i in 1 : length (x)))
231+ elseif func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
232+ basefunc = Enzyme. BatchDuplicatedNoNeed (basefunc, Tuple (make_zero (basefunc) for i in 1 : length (x)))
233+ end
220234 # else
221235 # seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
222236 # Jaccache = Tuple(zero(x) for i in 1:num_cons)
@@ -225,11 +239,16 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
225239 y = zeros (eltype (x), num_cons)
226240
227241 function cons_j! (J, θ)
228- for i in eachindex ( Jaccache)
229- Enzyme. make_zero! (Jaccache[i] )
242+ for jc in Jaccache
243+ Enzyme. make_zero! (jc )
230244 end
231245 Enzyme. make_zero! (y)
232- Enzyme. autodiff (fmode, f. cons, BatchDuplicated (y, Jaccache),
246+ if func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated || func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
247+ for bf in basefunc. dval
248+ Enzyme. make_zero! (bf)
249+ end
250+ end
251+ Enzyme. autodiff (fmode, basfunc, BatchDuplicated (y, Jaccache),
233252 BatchDuplicated (θ, seeds), Const (p))
234253 for i in eachindex (θ)
235254 if J isa Vector
0 commit comments