Skip to content

Commit 9dba82e

Browse files
committed
Handle potential ND data
1 parent 1a9bc39 commit 9dba82e

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

src/NewtonKrylov.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using Enzyme
1212
import LinearAlgebra: mul!
1313

1414
function maybe_duplicated(f, df)
15-
if Enzyme.Compiler.active_reg(typeof(f))
15+
if !Enzyme.Compiler.guaranteed_const(typeof(f))
1616
return DuplicatedNoNeed(f, df)
1717
else
1818
return Const(f)
@@ -34,11 +34,13 @@ Base.size(J::JacobianOperator) = (length(J.res), length(J.u))
3434
Base.eltype(J::JacobianOperator) = eltype(J.u)
3535

3636
function mul!(out, J::JacobianOperator, v)
37-
Enzyme.make_zero!(J.f_cache)
37+
# Enzyme.make_zero!(J.f_cache)
38+
f_cache = Enzyme.make_zero(J.f) # Stop gap until we can zero out mutable values
3839
autodiff(
3940
Forward,
40-
maybe_duplicated(J.f, J.f_cache), Const,
41-
DuplicatedNoNeed(J.res, out), DuplicatedNoNeed(J.u, v)
41+
maybe_duplicated(J.f, f_cache), Const,
42+
DuplicatedNoNeed(J.res, reshape(out, size(J.res))),
43+
DuplicatedNoNeed(J.u, reshape(v, size(J.u)))
4244
)
4345
return nothing
4446
end
@@ -57,7 +59,8 @@ function mul!(out, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:A
5759
autodiff(
5860
Reverse,
5961
maybe_duplicated(J.f, J.f_cache), Const,
60-
DuplicatedNoNeed(J.res, copy(v)), DuplicatedNoNeed(J.u, out)
62+
DuplicatedNoNeed(J.res, reshape(copy(v), size(J.res))),
63+
DuplicatedNoNeed(J.u, reshape(out, size(J.u)))
6164
)
6265
return nothing
6366
end

0 commit comments

Comments
 (0)