@@ -84,6 +84,8 @@ function set_runtime_activity2(
8484 a:: Mode1 , :: Enzyme.Mode{ABI, Err, RTA} ) where {Mode1, ABI, Err, RTA}
8585 Enzyme. set_runtime_activity (a, RTA)
8686end
87+ function_annotation (:: Nothing ) = Nothing
88+ function_annotation (:: AutoEnzyme{<:Any, A} ) where A = A
8789function OptimizationBase. instantiate_function (f:: OptimizationFunction{true} , x,
8890 adtype:: AutoEnzyme , p, num_cons = 0 ;
8991 g = false , h = false , hv = false , fg = false , fgh = false ,
@@ -101,6 +103,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
101103 set_runtime_activity2 (Enzyme. Forward, adtype. mode)
102104 end
103105
106+ func_annot = function_annotation (adtype)
107+
104108 if g == true && f. grad === nothing
105109 function grad (res, θ, p = p)
106110 Enzyme. make_zero! (res)
@@ -217,6 +221,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
217221 # if num_cons > length(x)
218222 seeds = Enzyme. onehot (x)
219223 Jaccache = Tuple (zeros (eltype (x), num_cons) for i in 1 : length (x))
224+ basefunc = f. cons
225+ if func_annot <: Enzyme.Const
226+ basefunc = Enzyme. Const (basefunc)
227+ elseif func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated
228+ basefunc = Enzyme. BatchDuplicated (basefunc, Tuple (make_zero (basefunc) for i in 1 : length (x)))
229+ elseif func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
230+ basefunc = Enzyme. BatchDuplicatedNoNeed (basefunc, Tuple (make_zero (basefunc) for i in 1 : length (x)))
231+ end
220232 # else
221233 # seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
222234 # Jaccache = Tuple(zero(x) for i in 1:num_cons)
@@ -225,11 +237,16 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
225237 y = zeros (eltype (x), num_cons)
226238
227239 function cons_j! (J, θ)
228- for i in eachindex ( Jaccache)
229- Enzyme. make_zero! (Jaccache[i] )
240+ for jc in Jaccache
241+ Enzyme. make_zero! (jc )
230242 end
231243 Enzyme. make_zero! (y)
232- Enzyme. autodiff (fmode, f. cons, BatchDuplicated (y, Jaccache),
244+ if func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated || func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
245+ for bf in basefunc. dval
246+ Enzyme. make_zero! (bf)
247+ end
248+ end
249+ Enzyme. autodiff (fmode, basefunc , BatchDuplicated (y, Jaccache),
233250 BatchDuplicated (θ, seeds), Const (p))
234251 for i in eachindex (θ)
235252 if J isa Vector
0 commit comments