Skip to content

Commit b0649f5

Browse files
chore: set tangent to .u
1 parent 054fecf commit b0649f5

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/concrete_solve.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,13 @@ function DiffEqBase._concrete_solve_adjoint(
545545
outtype = _out isa SubArray ?
546546
ArrayInterface.parameterless_type(_out.parent) :
547547
ArrayInterface.parameterless_type(_out)
548+
Δu = Δ isa Tangent ? unthunk.(Δ.u) : Δ
548549
if only_end
549-
eltype(Δ) <: NoTangent && return
550-
if (Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray) &&
551-
length(Δ) == 1 && i == 1
550+
eltype(Δu) <: NoTangent && return
551+
if (Δu isa AbstractArray{<:AbstractArray} || Δu isa AbstractVectorOfArray) &&
552+
length(Δu) == 1 && i == 1
552553
# user did sol[end] on only_end
553-
x = Δ isa AbstractVectorOfArray ? Δ.u[1] : Δ[1]
554+
x = Δu isa AbstractVectorOfArray ? Δu[1] : Δu[1]
554555
if _save_idxs isa Number
555556
vx = vec(x)
556557
_out[_save_idxs] .= vx[_save_idxs]
@@ -563,12 +564,12 @@ function DiffEqBase._concrete_solve_adjoint(
563564
else
564565
Δ isa NoTangent && return
565566
if _save_idxs isa Number
566-
x = vec(Δ)
567+
x = vec(Δu)
567568
_out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs]))
568569
elseif _save_idxs isa Colon
569-
vec(_out) .= vec(adapt(outtype, Δ))
570+
vec(_out) .= vec(adapt(outtype, Δu))
570571
else
571-
x = vec(Δ)
572+
x = vec(Δu)
572573
vec(@view(_out[_save_idxs])) .= adapt(outtype, @view(x[_save_idxs]))
573574
end
574575
end

0 commit comments

Comments
 (0)