Skip to content

Commit d3e44c4

Browse files
authored
DuplicatedNoNeed will provide incorrect results if values are read-write like for a cache (#14)
1 parent ee18b55 commit d3e44c4

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

src/NewtonKrylov.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import LinearAlgebra: mul!
1313

1414
function maybe_duplicated(f, df)
1515
if !Enzyme.Compiler.guaranteed_const(typeof(f))
16-
return DuplicatedNoNeed(f, df)
16+
return Duplicated(f, df)
1717
else
1818
return Const(f)
1919
end
@@ -37,15 +37,16 @@ end
3737

3838
Base.size(J::JacobianOperator) = (length(J.res), length(J.u))
3939
Base.eltype(J::JacobianOperator) = eltype(J.u)
40+
Base.length(J::JacobianOperator) = prod(size(J))
4041

4142
function mul!(out, J::JacobianOperator, v)
4243
# Enzyme.make_zero!(J.f_cache)
4344
f_cache = Enzyme.make_zero(J.f) # Stop gap until we can zero out mutable values
4445
autodiff(
4546
Forward,
4647
maybe_duplicated(J.f, f_cache), Const,
47-
DuplicatedNoNeed(J.res, reshape(out, size(J.res))),
48-
DuplicatedNoNeed(J.u, reshape(v, size(J.u)))
48+
Duplicated(J.res, reshape(out, size(J.res))),
49+
Duplicated(J.u, reshape(v, size(J.u)))
4950
)
5051
return nothing
5152
end
@@ -61,16 +62,18 @@ function mul!(out, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:A
6162
Enzyme.make_zero!(J.f_cache)
6263
# TODO: provide cache for `copy(v)`
6364
# Enzyme zeros input derivatives and that confuses the solvers.
65+
# If `out` is non-zero we might get spurious gradients
66+
fill!(out, 0)
6467
autodiff(
6568
Reverse,
6669
maybe_duplicated(J.f, J.f_cache), Const,
67-
DuplicatedNoNeed(J.res, reshape(copy(v), size(J.res))),
68-
DuplicatedNoNeed(J.u, reshape(out, size(J.u)))
70+
Duplicated(J.res, reshape(copy(v), size(J.res))),
71+
Duplicated(J.u, reshape(out, size(J.u)))
6972
)
7073
return nothing
7174
end
7275

73-
function Base.collect(JOp::JacobianOperator)
76+
function Base.collect(JOp::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:Any, <:JacobianOperator}, JacobianOperator})
7477
N, M = size(JOp)
7578
v = zeros(eltype(JOp), M)
7679
out = zeros(eltype(JOp), N)

test/runtests.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,19 @@ using Enzyme, LinearAlgebra
2828
@testset "Jacobian" begin
2929
J_Enz = jacobian(Forward, F, [3.0, 5.0]) |> only
3030
J = JacobianOperator(F!, zeros(2), [3.0, 5.0])
31+
32+
@test size(J) == (2, 2)
33+
@test length(J) == 4
34+
@test eltype(J) == Float64
35+
36+
out = [NaN, NaN]
37+
mul!(out, J, [1.0, 0.0])
38+
@test out == [6.0, 7.38905609893065]
39+
40+
out = [NaN, NaN]
41+
mul!(out, transpose(J), [1.0, 0.0])
42+
@test out == [6.0, 10.0]
43+
3144
J_NK = collect(J)
3245

3346
@test J_NK == J_Enz
@@ -37,4 +50,6 @@ using Enzyme, LinearAlgebra
3750
mul!(out, J, v)
3851

3952
@test out J_Enz * v
53+
54+
@test collect(transpose(J)) == transpose(collect(J))
4055
end

0 commit comments

Comments
 (0)