@@ -11,51 +11,130 @@ using Enzyme
1111# #
1212import LinearAlgebra: mul!
1313
14- function maybe_duplicated (x, :: Val{N} = Val (1 )) where {N}
15- # TODO cache?
"""
    init_cache(x)

Create the cache that is later paired with `x` by `maybe_duplicated`.

Returns `nothing` when `Enzyme.Compiler.guaranteed_const(typeof(x))` holds,
and the result of `Enzyme.make_zero(x)` otherwise.
"""
function init_cache(x)
    Enzyme.Compiler.guaranteed_const(typeof(x)) && return nothing
    return Enzyme.make_zero(x)
end
21+
"""
    maybe_duplicated(x, x′)

Wrap `x` in the Enzyme annotation that matches its cache `x′`.

When `x′ === nothing` the result is `Const(x)`. Otherwise `x′` is reset with
`Enzyme.remake_zero!` and `Duplicated(x, x′)` is returned.
"""
function maybe_duplicated(x::T, x′::Union{Nothing, T}) where {T}
    isnothing(x′) && return Const(x)
    # The cache may still hold values from an earlier call; reset it first.
    Enzyme.remake_zero!(x′)
    return Duplicated(x, x′)
end
2630
27- # TODO : JacobianOperator with thunk
"""
    AbstractJacobianOperator

Common supertype of `JacobianOperator` and `BatchedJacobianOperator`.
"""
abstract type AbstractJacobianOperator end
32+
2833
"""
    JacobianOperator

Efficient implementation of `J(f, x, p) * v` and `v * J(f, x, p)'`.

`f` is invoked as `f(res, u, p)`. The fields `f′` and `p′` hold the values
returned by `init_cache` for `f` and `p`; they are `nothing` when no cache
was created.
"""
struct JacobianOperator{F, A, P} <: AbstractJacobianOperator
    f::F                     # F!(res, u, p)
    f′::Union{Nothing, F}    # cache for `f`
    res::A
    u::A
    p::P
    p′::Union{Nothing, P}    # cache for `p`

    function JacobianOperator(f::F, res, u, p) where {F}
        A = typeof(u)
        P = typeof(p)
        return new{F, A, P}(f, init_cache(f), res, u, p, init_cache(p))
    end
end
4352
# A plain `JacobianOperator` always handles a single direction per call.
function batch_size(::JacobianOperator)
    return 1
end

# Array-like queries: rows come from `res`, columns from `u`.
function Base.size(J::JacobianOperator)
    nrows = length(J.res)
    ncols = length(J.u)
    return (nrows, ncols)
end

function Base.eltype(J::JacobianOperator)
    return eltype(J.u)
end

function Base.length(J::JacobianOperator)
    return prod(size(J))
end
4758
48- function mul! (out:: AbstractVector , J:: JacobianOperator , v:: AbstractVector )
# Forward-mode application of the operator: `out` and `v` are reshaped to the
# shapes of `J.res` and `J.u` and used as the shadows of those buffers in a
# single `autodiff(Forward, ...)` call. Returns `nothing`.
function mul!(out, J::JacobianOperator, v)
    f = maybe_duplicated(J.f, J.f′)
    res = Duplicated(J.res, reshape(out, size(J.res)))
    u = Duplicated(J.u, reshape(v, size(J.u)))
    p = maybe_duplicated(J.p, J.p′)

    autodiff(Forward, f, Const, res, u, p)
    return nothing
end
5869
# `adjoint` / `transpose` only wrap the operator in the corresponding lazy
# LinearAlgebra wrapper type.
function LinearAlgebra.adjoint(J::JacobianOperator)
    return Adjoint(J)
end

function LinearAlgebra.transpose(J::JacobianOperator)
    return Transpose(J)
end
72+
73+ # Jᵀ(y, u) = ForwardDiff.gradient!(y, x -> dot(F(x), u), xk)
74+ # or just reverse mode
75+
# Reverse-mode application for the wrapped (adjoint / transposed) operator.
# `out` is zeroed, reshaped to the shape of `J.u` and used as the shadow of
# `J.u`; a copy of `v` is used as the shadow of `J.res`. Returns `nothing`.
function mul!(
        out,
        J′::Union{
            Adjoint{<:Any, <:JacobianOperator},
            Transpose{<:Any, <:JacobianOperator}
        },
        v
    )
    op = parent(J′)

    # If `out` is non-zero we might get spurious gradients
    fill!(out, 0)

    f = maybe_duplicated(op.f, op.f′)
    # TODO: provide cache for `copy(v)`
    # Enzyme zeros input derivatives and that confuses the solvers.
    res = Duplicated(op.res, reshape(copy(v), size(op.res)))
    u = Duplicated(op.u, reshape(out, size(op.u)))
    p = maybe_duplicated(op.p, op.p′)

    autodiff(Reverse, f, Const, res, u, p)
    return nothing
end
91+
92+
"""
    init_cache(x, ::Val{N})

Batched variant of `init_cache(x)`.

Returns `nothing` when `Enzyme.Compiler.guaranteed_const(typeof(x))` holds,
and otherwise an `N`-tuple whose entries are each produced by a separate
`Enzyme.make_zero(x)` call.
"""
function init_cache(x, ::Val{N}) where {N}
    Enzyme.Compiler.guaranteed_const(typeof(x)) && return nothing
    return ntuple(Val(N)) do _
        Enzyme.make_zero(x)
    end
end
100+
"""
    maybe_duplicated(x, x′, ::Val{N})

Batched variant of `maybe_duplicated(x, x′)`.

When `x′ === nothing` the result is `Const(x)`. Otherwise the `N`-tuple `x′`
is reset with `Enzyme.remake_zero!` and `BatchDuplicated(x, x′)` is returned.
"""
function maybe_duplicated(
        x::T, x′::Union{Nothing, NTuple{N, T}}, ::Val{N}
    ) where {T, N}
    isnothing(x′) && return Const(x)
    # The cache may still hold values from an earlier call; reset it first.
    Enzyme.remake_zero!(x′)
    return BatchDuplicated(x, x′)
end
109+
"""
    BatchedJacobianOperator{N}

Counterpart of `JacobianOperator` whose batch size `N` is part of the type.

`f` is invoked as `f(res, u, p)`. The fields `f′` and `p′` hold the values
returned by `init_cache(x, Val(N))` for `f` and `p`: either `nothing` or an
`N`-tuple of caches.
"""
struct BatchedJacobianOperator{N, F, A, P} <: AbstractJacobianOperator
    f::F                                 # F!(res, u, p)
    f′::Union{Nothing, NTuple{N, F}}     # cache for `f`
    res::A
    u::A
    p::P
    p′::Union{Nothing, NTuple{N, P}}     # cache for `p`

    function BatchedJacobianOperator{N}(f::F, res, u, p) where {F, N}
        A = typeof(u)
        P = typeof(p)
        width = Val(N)
        return new{N, F, A, P}(
            f, init_cache(f, width), res, u, p, init_cache(p, width)
        )
    end
end
128+
# The batch size is carried in the first type parameter.
function batch_size(::BatchedJacobianOperator{N}) where {N}
    return N
end

# Array-like queries: rows come from `res`, columns from `u`.
function Base.size(J::BatchedJacobianOperator)
    nrows = length(J.res)
    ncols = length(J.u)
    return (nrows, ncols)
end

function Base.eltype(J::BatchedJacobianOperator)
    return eltype(J.u)
end

function Base.length(J::BatchedJacobianOperator)
    return prod(size(J))
end
134+
# `adjoint` / `transpose` only wrap the operator in the corresponding lazy
# LinearAlgebra wrapper type.
function LinearAlgebra.adjoint(J::BatchedJacobianOperator)
    return Adjoint(J)
end

function LinearAlgebra.transpose(J::BatchedJacobianOperator)
    return Transpose(J)
end
137+
59138if VERSION >= v " 1.11.0"
60139
61140 function tuple_of_vectors (M:: Matrix{T} , shape) where {T}
@@ -66,49 +145,23 @@ if VERSION >= v"1.11.0"
66145 end
67146 end
68147
69- function mul! (Out:: AbstractMatrix , J:: JacobianOperator , V:: AbstractMatrix )
# Batched forward-mode application of the operator. `Out` and `V` are turned
# into tuples with `tuple_of_vectors` and used as the batched shadows of
# `J.res` and `J.u`. `Out` and `V` must agree in their second dimension, and
# the tuple built from `Out` must have exactly `N` entries. Returns `nothing`.
function mul!(Out, J::BatchedJacobianOperator{N}, V) where {N}
    @assert size(Out, 2) == size(V, 2)
    outs = tuple_of_vectors(Out, size(J.res))
    vs = tuple_of_vectors(V, size(J.u))

    @assert N == length(outs)

    width = Val(N)
    f = maybe_duplicated(J.f, J.f′, width)
    res = BatchDuplicated(J.res, outs)
    u = BatchDuplicated(J.u, vs)
    p = maybe_duplicated(J.p, J.p′, width)

    autodiff(Forward, f, Const, res, u, p)
    return nothing
end
84163
85- end # VERSION >= v"1.11.0"
86-
87- LinearAlgebra. adjoint (J:: JacobianOperator ) = Adjoint (J)
88- LinearAlgebra. transpose (J:: JacobianOperator ) = Transpose (J)
89-
90- # Jᵀ(y, u) = ForwardDiff.gradient!(y, x -> dot(F(x), u), xk)
91- # or just reverse mode
92-
93- function mul! (out:: AbstractVector , J′:: Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}} , v:: AbstractVector )
94- J = parent (J′)
95- # TODO : provide cache for `copy(v)`
96- # Enzyme zeros input derivatives and that confuses the solvers.
97- # If `out` is non-zero we might get spurious gradients
98- fill! (out, 0 )
99- autodiff (
100- Reverse,
101- maybe_duplicated (J. f), Const,
102- Duplicated (J. res, reshape (copy (v), size (J. res))),
103- Duplicated (J. u, reshape (out, size (J. u))),
104- maybe_duplicated (J. p)
105- )
106- return nothing
107- end
108-
109- if VERSION >= v " 1.11.0"
110-
111- function mul! (Out:: AbstractMatrix , J′:: Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}} , V:: AbstractMatrix )
164+ function mul! (Out, J′:: Union{Adjoint{<:Any, <:BatchedJacobianOperator{N}}, Transpose{<:Any, <:BatchedJacobianOperator{N}}} , V) where {N}
112165 J = parent (J′)
113166 @assert size (Out, 2 ) == size (V, 2 )
114167
@@ -122,22 +175,20 @@ if VERSION >= v"1.11.0"
122175 out = tuple_of_vectors (Out, size (J. u))
123176 v = tuple_of_vectors (V, size (J. res))
124177
125- N = length (out)
178+ @assert N = = length (out)
126179
127- # TODO : BatchDuplicated for J.f
128180 autodiff (
129181 Reverse,
130- maybe_duplicated (J. f, Val (N)), Const,
182+ maybe_duplicated (J. f, J . f′, Val (N)), Const,
131183 BatchDuplicated (J. res, v),
132184 BatchDuplicated (J. u, out),
133- maybe_duplicated (J. p, Val (N))
185+ maybe_duplicated (J. p, J . p′, Val (N))
134186 )
135187 return nothing
136188 end
137-
138189end # VERSION >= v"1.11.0"
139190
140- function Base. collect (JOp:: Union{Adjoint{<:Any, <:JacobianOperator }, Transpose{<:Any, <:JacobianOperator }, JacobianOperator } )
191+ function Base. collect (JOp:: Union{Adjoint{<:Any, <:AbstractJacobianOperator }, Transpose{<:Any, <:AbstractJacobianOperator }, AbstractJacobianOperator } )
141192 N, M = size (JOp)
142193 if JOp isa JacobianOperator
143194 v = zero (JOp. u)
0 commit comments