@@ -13,7 +13,7 @@ import LinearAlgebra: mul!
1313
1414function maybe_duplicated (f, df)
1515 if ! Enzyme. Compiler. guaranteed_const (typeof (f))
16- return DuplicatedNoNeed (f, df)
16+ return Duplicated (f, df)
1717 else
1818 return Const (f)
1919 end
3737
3838Base. size (J:: JacobianOperator ) = (length (J. res), length (J. u))
3939Base. eltype (J:: JacobianOperator ) = eltype (J. u)
40+ Base. length (J:: JacobianOperator ) = prod (size (J))
4041
4142function mul! (out, J:: JacobianOperator , v)
4243 # Enzyme.make_zero!(J.f_cache)
4344 f_cache = Enzyme. make_zero (J. f) # Stop gap until we can zero out mutable values
4445 autodiff (
4546 Forward,
4647 maybe_duplicated (J. f, f_cache), Const,
47- DuplicatedNoNeed (J. res, reshape (out, size (J. res))),
48- DuplicatedNoNeed (J. u, reshape (v, size (J. u)))
48+ Duplicated (J. res, reshape (out, size (J. res))),
49+ Duplicated (J. u, reshape (v, size (J. u)))
4950 )
5051 return nothing
5152end
@@ -61,16 +62,18 @@ function mul!(out, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:A
6162 Enzyme. make_zero! (J. f_cache)
6263 # TODO : provide cache for `copy(v)`
6364 # Enzyme zeros input derivatives and that confuses the solvers.
65+ # If `out` is non-zero we might get spurious gradients
66+ fill! (out, 0 )
6467 autodiff (
6568 Reverse,
6669 maybe_duplicated (J. f, J. f_cache), Const,
67- DuplicatedNoNeed (J. res, reshape (copy (v), size (J. res))),
68- DuplicatedNoNeed (J. u, reshape (out, size (J. u)))
70+ Duplicated (J. res, reshape (copy (v), size (J. res))),
71+ Duplicated (J. u, reshape (out, size (J. u)))
6972 )
7073 return nothing
7174end
7275
73- function Base. collect (JOp:: JacobianOperator )
76+ function Base. collect (JOp:: Union{Adjoint{<:Any, <: JacobianOperator}, Transpose{<:Any, <:JacobianOperator}, JacobianOperator} )
7477 N, M = size (JOp)
7578 v = zeros (eltype (JOp), M)
7679 out = zeros (eltype (JOp), N)
0 commit comments