Skip to content

Commit ea85816

Browse files
authored
add batched JacobianOperator (#15)
1 parent d3e44c4 commit ea85816

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

src/NewtonKrylov.jl

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ function maybe_duplicated(f, df)
1919
end
2020
end
2121

22+
# TODO: JacobianOperator with thunk
23+
2224
"""
2325
JacobianOperator
2426
@@ -39,7 +41,7 @@ Base.size(J::JacobianOperator) = (length(J.res), length(J.u))
3941
Base.eltype(J::JacobianOperator) = eltype(J.u)
4042
Base.length(J::JacobianOperator) = prod(size(J))
4143

42-
function mul!(out, J::JacobianOperator, v)
44+
function mul!(out::AbstractVector, J::JacobianOperator, v::AbstractVector)
4345
# Enzyme.make_zero!(J.f_cache)
4446
f_cache = Enzyme.make_zero(J.f) # Stop gap until we can zero out mutable values
4547
autodiff(
@@ -51,13 +53,41 @@ function mul!(out, J::JacobianOperator, v)
5153
return nothing
5254
end
5355

56+
if VERSION >= v"1.11.0"
57+
58+
function tuple_of_vectors(M::Matrix{T}, shape) where {T}
59+
n, m = size(M)
60+
return ntuple(m) do i
61+
vec = Base.wrap(Array, memoryref(M.ref, (i - 1) * n + 1), (n,))
62+
reshape(vec, shape)
63+
end
64+
end
65+
66+
function mul!(Out::AbstractMatrix, J::JacobianOperator, V::AbstractMatrix)
67+
@assert size(Out, 2) == size(V, 2)
68+
out = tuple_of_vectors(Out, size(J.res))
69+
v = tuple_of_vectors(V, size(J.u))
70+
71+
# f_cache = Enzyme.make_zero(J.f)
72+
# TODO: BatchDuplicated for J.f
73+
autodiff(
74+
Forward,
75+
Const(J.f), Const,
76+
BatchDuplicated(J.res, out),
77+
BatchDuplicated(J.u, v)
78+
)
79+
return nothing
80+
end
81+
82+
end # VERSION >= v"1.11.0"
83+
5484
LinearAlgebra.adjoint(J::JacobianOperator) = Adjoint(J)
5585
LinearAlgebra.transpose(J::JacobianOperator) = Transpose(J)
5686

5787
# Jᵀ(y, u) = ForwardDiff.gradient!(y, x -> dot(F(x), u), xk)
5888
# or just reverse mode
5989

60-
function mul!(out, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}}, v)
90+
function mul!(out::AbstractVector, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}}, v::AbstractVector)
6191
J = parent(J′)
6292
Enzyme.make_zero!(J.f_cache)
6393
# TODO: provide cache for `copy(v)`
@@ -73,6 +103,34 @@ function mul!(out, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:A
73103
return nothing
74104
end
75105

106+
if VERSION >= v"1.11.0"
107+
108+
function mul!(Out::AbstractMatrix, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}}, V::AbstractMatrix)
109+
J = parent(J′)
110+
@assert size(Out, 2) == size(V, 2)
111+
112+
# If `out` is non-zero we might get spurious gradients
113+
fill!(Out, 0)
114+
115+
# TODO: provide cache for `copy(v)`
116+
# Enzyme zeros input derivatives and that confuses the solvers.
117+
V = copy(V)
118+
119+
out = tuple_of_vectors(Out, size(J.u))
120+
v = tuple_of_vectors(V, size(J.res))
121+
122+
# TODO: BatchDuplicated for J.f
123+
autodiff(
124+
Reverse,
125+
Const(J.f), Const,
126+
BatchDuplicated(J.res, v),
127+
BatchDuplicated(J.u, out)
128+
)
129+
return nothing
130+
end
131+
132+
end # VERSION >= v"1.11.0"
133+
76134
function Base.collect(JOp::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}, JacobianOperator})
77135
N, M = size(JOp)
78136
v = zeros(eltype(JOp), M)

test/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,16 @@ using Enzyme, LinearAlgebra
5252
@test out J_Enz * v
5353

5454
@test collect(transpose(J)) == transpose(collect(J))
55+
56+
# Batched
57+
if VERSION >= v"1.11.0"
58+
V = [1.0 0.0; 0.0 1.0]
59+
Out = similar(V)
60+
mul!(Out, J, V)
61+
62+
@test Out == J_Enz
63+
64+
mul!(Out, transpose(J), V)
65+
@test Out == collect(transpose(J))
66+
end
5567
end

0 commit comments

Comments
 (0)