@@ -545,12 +545,13 @@ function DiffEqBase._concrete_solve_adjoint(
545545 outtype = _out isa SubArray ?
546546 ArrayInterface. parameterless_type (_out. parent) :
547547 ArrayInterface. parameterless_type (_out)
548+ Δu = Δ isa Tangent ? unthunk .(Δ. u) : Δ
548549 if only_end
549- eltype (Δ ) <: NoTangent && return
550- if (Δ isa AbstractArray{<: AbstractArray } || Δ isa AbstractVectorOfArray) &&
551- length (Δ ) == 1 && i == 1
550+ eltype (Δu ) <: NoTangent && return
551+ if (Δu isa AbstractArray{<: AbstractArray } || Δu isa AbstractVectorOfArray) &&
552+ length (Δu ) == 1 && i == 1
552553 # user did sol[end] on only_end
553- x = Δ isa AbstractVectorOfArray ? Δ . u [1 ] : Δ [1 ]
554+ x = Δu isa AbstractVectorOfArray ? Δu [1 ] : Δu [1 ]
554555 if _save_idxs isa Number
555556 vx = vec (x)
556557 _out[_save_idxs] .= vx[_save_idxs]
@@ -563,12 +564,12 @@ function DiffEqBase._concrete_solve_adjoint(
563564 else
564565 Δ isa NoTangent && return
565566 if _save_idxs isa Number
566- x = vec (Δ )
567+ x = vec (Δu )
567568 _out[_save_idxs] .= adapt (outtype, @view (x[_save_idxs]))
568569 elseif _save_idxs isa Colon
569- vec (_out) .= vec (adapt (outtype, Δ ))
570+ vec (_out) .= vec (adapt (outtype, Δu ))
570571 else
571- x = vec (Δ )
572+ x = vec (Δu )
572573 vec (@view (_out[_save_idxs])) .= adapt (outtype, @view (x[_save_idxs]))
573574 end
574575 end
0 commit comments