Skip to content

Commit 25defdf

Browse files
authored
Use remake_zero! to implement caching for shadows (#33)
1 parent c6ec180 commit 25defdf

File tree

4 files changed

+108
-53
lines changed

4 files changed

+108
-53
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1111

1212
[compat]
13-
Enzyme = "0.13"
13+
Enzyme = "0.13.50"
1414
Krylov = "0.10.1"
1515
LinearAlgebra = "1.10"
1616
julia = "1.10"

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Ariadne.EisenstatWalker
2121

2222
```@docs
2323
Ariadne.JacobianOperator
24+
Ariadne.BatchedJacobianOperator
2425
```
2526

2627
## Bibliography

src/Ariadne.jl

Lines changed: 101 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -11,51 +11,130 @@ using Enzyme
1111
##
1212
import LinearAlgebra: mul!
1313

14-
function maybe_duplicated(x, ::Val{N} = Val(1)) where {N}
15-
# TODO cache?
14+
function init_cache(x)
1615
if !Enzyme.Compiler.guaranteed_const(typeof(x))
17-
if N == 1
18-
return Duplicated(x, Enzyme.make_zero(x))
19-
else
20-
return BatchDuplicated(x, ntuple(_ -> Enzyme.make_zero(x), Val(N)))
21-
end
16+
Enzyme.make_zero(x)
2217
else
18+
return nothing
19+
end
20+
end
21+
22+
function maybe_duplicated(x::T, x′::Union{Nothing, T}) where {T}
23+
if x′ === nothing
2324
return Const(x)
25+
else
26+
Enzyme.remake_zero!(x′)
27+
return Duplicated(x, x′)
2428
end
2529
end
2630

27-
# TODO: JacobianOperator with thunk
31+
abstract type AbstractJacobianOperator end
32+
2833

2934
"""
3035
JacobianOperator
3136
3237
Efficient implementation of `J(f,x,p) * v` and `v * J(f, x,p)'`
3338
"""
34-
struct JacobianOperator{F, A, P}
39+
struct JacobianOperator{F, A, P} <: AbstractJacobianOperator
3540
f::F # F!(res, u, p)
41+
f′::Union{Nothing, F} # cache
3642
res::A
3743
u::A
3844
p::P
45+
p′::Union{Nothing, P} # cache
3946
function JacobianOperator(f::F, res, u, p) where {F}
40-
return new{F, typeof(u), typeof(p)}(f, res, u, p)
47+
f′ = init_cache(f)
48+
p′ = init_cache(p)
49+
return new{F, typeof(u), typeof(p)}(f, f′, res, u, p, p′)
4150
end
4251
end
4352

53+
batch_size(::JacobianOperator) = 1
54+
4455
Base.size(J::JacobianOperator) = (length(J.res), length(J.u))
4556
Base.eltype(J::JacobianOperator) = eltype(J.u)
4657
Base.length(J::JacobianOperator) = prod(size(J))
4758

48-
function mul!(out::AbstractVector, J::JacobianOperator, v::AbstractVector)
59+
function mul!(out, J::JacobianOperator, v)
4960
autodiff(
5061
Forward,
51-
maybe_duplicated(J.f), Const,
62+
maybe_duplicated(J.f, J.f′), Const,
5263
Duplicated(J.res, reshape(out, size(J.res))),
5364
Duplicated(J.u, reshape(v, size(J.u))),
54-
maybe_duplicated(J.p)
65+
maybe_duplicated(J.p, J.p′)
5566
)
5667
return nothing
5768
end
5869

70+
LinearAlgebra.adjoint(J::JacobianOperator) = Adjoint(J)
71+
LinearAlgebra.transpose(J::JacobianOperator) = Transpose(J)
72+
73+
# Jᵀ(y, u) = ForwardDiff.gradient!(y, x -> dot(F(x), u), xk)
74+
# or just reverse mode
75+
76+
function mul!(out, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}}, v)
77+
J = parent(J′)
78+
# TODO: provide cache for `copy(v)`
79+
# Enzyme zeros input derivatives and that confuses the solvers.
80+
# If `out` is non-zero we might get spurious gradients
81+
fill!(out, 0)
82+
autodiff(
83+
Reverse,
84+
maybe_duplicated(J.f, J.f′), Const,
85+
Duplicated(J.res, reshape(copy(v), size(J.res))),
86+
Duplicated(J.u, reshape(out, size(J.u))),
87+
maybe_duplicated(J.p, J.p′)
88+
)
89+
return nothing
90+
end
91+
92+
93+
function init_cache(x, ::Val{N}) where {N}
94+
if !Enzyme.Compiler.guaranteed_const(typeof(x))
95+
return ntuple(_ -> Enzyme.make_zero(x), Val(N))
96+
else
97+
return nothing
98+
end
99+
end
100+
101+
function maybe_duplicated(x::T, x′::Union{Nothing, NTuple{N, T}}, ::Val{N}) where {T, N}
102+
if x′ === nothing
103+
return Const(x)
104+
else
105+
Enzyme.remake_zero!(x′)
106+
return BatchDuplicated(x, x′)
107+
end
108+
end
109+
110+
"""
111+
BatchedJacobianOperator{N}
112+
113+
114+
"""
115+
struct BatchedJacobianOperator{N, F, A, P} <: AbstractJacobianOperator
116+
f::F # F!(res, u, p)
117+
f′::Union{Nothing, NTuple{N, F}} # cache
118+
res::A
119+
u::A
120+
p::P
121+
p′::Union{Nothing, NTuple{N, P}} # cache
122+
function BatchedJacobianOperator{N}(f::F, res, u, p) where {F, N}
123+
f′ = init_cache(f, Val(N))
124+
p′ = init_cache(p, Val(N))
125+
return new{N, F, typeof(u), typeof(p)}(f, f′, res, u, p, p′)
126+
end
127+
end
128+
129+
batch_size(::BatchedJacobianOperator{N}) where {N} = N
130+
131+
Base.size(J::BatchedJacobianOperator) = (length(J.res), length(J.u))
132+
Base.eltype(J::BatchedJacobianOperator) = eltype(J.u)
133+
Base.length(J::BatchedJacobianOperator) = prod(size(J))
134+
135+
LinearAlgebra.adjoint(J::BatchedJacobianOperator) = Adjoint(J)
136+
LinearAlgebra.transpose(J::BatchedJacobianOperator) = Transpose(J)
137+
59138
if VERSION >= v"1.11.0"
60139

61140
function tuple_of_vectors(M::Matrix{T}, shape) where {T}
@@ -66,49 +145,23 @@ if VERSION >= v"1.11.0"
66145
end
67146
end
68147

69-
function mul!(Out::AbstractMatrix, J::JacobianOperator, V::AbstractMatrix)
148+
function mul!(Out, J::BatchedJacobianOperator{N}, V) where {N}
70149
@assert size(Out, 2) == size(V, 2)
71150
out = tuple_of_vectors(Out, size(J.res))
72151
v = tuple_of_vectors(V, size(J.u))
73152

74-
N = length(out)
153+
@assert N == length(out)
75154
autodiff(
76155
Forward,
77-
maybe_duplicated(J.f, Val(N)), Const,
156+
maybe_duplicated(J.f, J.f′, Val(N)), Const,
78157
BatchDuplicated(J.res, out),
79158
BatchDuplicated(J.u, v),
80-
maybe_duplicated(J.p, Val(N))
159+
maybe_duplicated(J.p, J.p′, Val(N))
81160
)
82161
return nothing
83162
end
84163

85-
end # VERSION >= v"1.11.0"
86-
87-
LinearAlgebra.adjoint(J::JacobianOperator) = Adjoint(J)
88-
LinearAlgebra.transpose(J::JacobianOperator) = Transpose(J)
89-
90-
# Jᵀ(y, u) = ForwardDiff.gradient!(y, x -> dot(F(x), u), xk)
91-
# or just reverse mode
92-
93-
function mul!(out::AbstractVector, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}}, v::AbstractVector)
94-
J = parent(J′)
95-
# TODO: provide cache for `copy(v)`
96-
# Enzyme zeros input derivatives and that confuses the solvers.
97-
# If `out` is non-zero we might get spurious gradients
98-
fill!(out, 0)
99-
autodiff(
100-
Reverse,
101-
maybe_duplicated(J.f), Const,
102-
Duplicated(J.res, reshape(copy(v), size(J.res))),
103-
Duplicated(J.u, reshape(out, size(J.u))),
104-
maybe_duplicated(J.p)
105-
)
106-
return nothing
107-
end
108-
109-
if VERSION >= v"1.11.0"
110-
111-
function mul!(Out::AbstractMatrix, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}}, V::AbstractMatrix)
164+
function mul!(Out, J′::Union{Adjoint{<:Any, <:BatchedJacobianOperator{N}}, Transpose{<:Any, <:BatchedJacobianOperator{N}}}, V) where {N}
112165
J = parent(J′)
113166
@assert size(Out, 2) == size(V, 2)
114167

@@ -122,22 +175,20 @@ if VERSION >= v"1.11.0"
122175
out = tuple_of_vectors(Out, size(J.u))
123176
v = tuple_of_vectors(V, size(J.res))
124177

125-
N = length(out)
178+
@assert N == length(out)
126179

127-
# TODO: BatchDuplicated for J.f
128180
autodiff(
129181
Reverse,
130-
maybe_duplicated(J.f, Val(N)), Const,
182+
maybe_duplicated(J.f, J.f′, Val(N)), Const,
131183
BatchDuplicated(J.res, v),
132184
BatchDuplicated(J.u, out),
133-
maybe_duplicated(J.p, Val(N))
185+
maybe_duplicated(J.p, J.p′, Val(N))
134186
)
135187
return nothing
136188
end
137-
138189
end # VERSION >= v"1.11.0"
139190

140-
function Base.collect(JOp::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}, JacobianOperator})
191+
function Base.collect(JOp::Union{Adjoint{<:Any, <:AbstractJacobianOperator}, Transpose{<:Any, <:AbstractJacobianOperator}, AbstractJacobianOperator})
141192
N, M = size(JOp)
142193
if JOp isa JacobianOperator
143194
v = zero(JOp.u)

test/runtests.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ let x₀ = [3.0, 5.0]
2222
@test stats.solved
2323
end
2424

25-
import Ariadne: JacobianOperator
25+
import Ariadne: JacobianOperator, BatchedJacobianOperator
2626
using Enzyme, LinearAlgebra
2727

2828
@testset "Jacobian" begin
@@ -55,13 +55,16 @@ using Enzyme, LinearAlgebra
5555

5656
# Batched
5757
if VERSION >= v"1.11.0"
58+
J = BatchedJacobianOperator{2}(F!, zeros(2), [3.0, 5.0], nothing)
59+
5860
V = [1.0 0.0; 0.0 1.0]
5961
Out = similar(V)
6062
mul!(Out, J, V)
6163

6264
@test Out == J_Enz
6365

6466
mul!(Out, transpose(J), V)
65-
@test Out == collect(transpose(J))
67+
@test Out == J_Enz'
68+
# @test Out == collect(transpose(J))
6669
end
6770
end

0 commit comments

Comments
 (0)