@@ -19,6 +19,8 @@ function maybe_duplicated(f, df)
1919 end
2020end
2121
22+ # TODO : JacobianOperator with thunk
23+
2224"""
2325 JacobianOperator
2426
@@ -39,7 +41,7 @@ Base.size(J::JacobianOperator) = (length(J.res), length(J.u))
3941Base. eltype (J:: JacobianOperator ) = eltype (J. u)
4042Base. length (J:: JacobianOperator ) = prod (size (J))
4143
42- function mul! (out, J:: JacobianOperator , v)
44+ function mul! (out:: AbstractVector , J:: JacobianOperator , v:: AbstractVector )
4345 # Enzyme.make_zero!(J.f_cache)
4446 f_cache = Enzyme. make_zero (J. f) # Stop gap until we can zero out mutable values
4547 autodiff (
@@ -51,13 +53,41 @@ function mul!(out, J::JacobianOperator, v)
5153 return nothing
5254end
5355
56+ if VERSION >= v " 1.11.0"
57+
58+ function tuple_of_vectors (M:: Matrix{T} , shape) where {T}
59+ n, m = size (M)
60+ return ntuple (m) do i
61+ vec = Base. wrap (Array, memoryref (M. ref, (i - 1 ) * n + 1 ), (n,))
62+ reshape (vec, shape)
63+ end
64+ end
65+
66+ function mul! (Out:: AbstractMatrix , J:: JacobianOperator , V:: AbstractMatrix )
67+ @assert size (Out, 2 ) == size (V, 2 )
68+ out = tuple_of_vectors (Out, size (J. res))
69+ v = tuple_of_vectors (V, size (J. u))
70+
71+ # f_cache = Enzyme.make_zero(J.f)
72+ # TODO : BatchDuplicated for J.f
73+ autodiff (
74+ Forward,
75+ Const (J. f), Const,
76+ BatchDuplicated (J. res, out),
77+ BatchDuplicated (J. u, v)
78+ )
79+ return nothing
80+ end
81+
82+ end # VERSION >= v"1.11.0"
83+
5484LinearAlgebra. adjoint (J:: JacobianOperator ) = Adjoint (J)
5585LinearAlgebra. transpose (J:: JacobianOperator ) = Transpose (J)
5686
5787# Jᵀ(y, u) = ForwardDiff.gradient!(y, x -> dot(F(x), u), xk)
5888# or just reverse mode
5989
60- function mul! (out, J′:: Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}} , v)
90+ function mul! (out:: AbstractVector , J′:: Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}} , v:: AbstractVector )
6191 J = parent (J′)
6292 Enzyme. make_zero! (J. f_cache)
6393 # TODO : provide cache for `copy(v)`
@@ -73,6 +103,34 @@ function mul!(out, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:A
73103 return nothing
74104end
75105
106+ if VERSION >= v " 1.11.0"
107+
108+ function mul! (Out:: AbstractMatrix , J′:: Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}} , V:: AbstractMatrix )
109+ J = parent (J′)
110+ @assert size (Out, 2 ) == size (V, 2 )
111+
112+ # If `out` is non-zero we might get spurious gradients
113+ fill! (Out, 0 )
114+
115+ # TODO : provide cache for `copy(v)`
116+ # Enzyme zeros input derivatives and that confuses the solvers.
117+ V = copy (V)
118+
119+ out = tuple_of_vectors (Out, size (J. u))
120+ v = tuple_of_vectors (V, size (J. res))
121+
122+ # TODO : BatchDuplicated for J.f
123+ autodiff (
124+ Reverse,
125+ Const (J. f), Const,
126+ BatchDuplicated (J. res, v),
127+ BatchDuplicated (J. u, out)
128+ )
129+ return nothing
130+ end
131+
132+ end # VERSION >= v"1.11.0"
133+
76134function Base. collect (JOp:: Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}, JacobianOperator} )
77135 N, M = size (JOp)
78136 v = zeros (eltype (JOp), M)
0 commit comments