Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 75fc2a3

Browse files
committed
Enzyme: add func_annotation
1 parent 6d96450 commit 75fc2a3

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
101101
set_runtime_activity2(Enzyme.Forward, adtype.mode)
102102
end
103103

104+
func_annot = if adtype.mode isa Nothing
105+
Nothing
106+
else
107+
adtype.mode.function_annotation
108+
end
109+
104110
if g == true && f.grad === nothing
105111
function grad(res, θ, p = p)
106112
Enzyme.make_zero!(res)
@@ -217,6 +223,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
217223
# if num_cons > length(x)
218224
seeds = Enzyme.onehot(x)
219225
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
226+
basefunc = f.cons
227+
if func_annot <: Enzyme.Const
228+
basefunc = Enzyme.Const(basefunc)
229+
elseif func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated
230+
basefunc = Enzyme.BatchDuplicated(basefunc, Tuple(make_zero(basefunc) for i in 1:length(x)))
231+
elseif func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
232+
basefunc = Enzyme.BatchDuplicatedNoNeed(basefunc, Tuple(make_zero(basefunc) for i in 1:length(x)))
233+
end
220234
# else
221235
# seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
222236
# Jaccache = Tuple(zero(x) for i in 1:num_cons)
@@ -225,11 +239,16 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
225239
y = zeros(eltype(x), num_cons)
226240

227241
function cons_j!(J, θ)
228-
for i in eachindex(Jaccache)
229-
Enzyme.make_zero!(Jaccache[i])
242+
for jc in Jaccache
243+
Enzyme.make_zero!(jc)
230244
end
231245
Enzyme.make_zero!(y)
232-
Enzyme.autodiff(fmode, f.cons, BatchDuplicated(y, Jaccache),
246+
if func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated || func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
247+
for bf in basefunc.dval
248+
Enzyme.make_zero!(bf)
249+
end
250+
end
251+
Enzyme.autodiff(fmode, basfunc, BatchDuplicated(y, Jaccache),
233252
BatchDuplicated(θ, seeds), Const(p))
234253
for i in eachindex(θ)
235254
if J isa Vector

0 commit comments

Comments
 (0)