@@ -545,12 +545,13 @@ function DiffEqBase._concrete_solve_adjoint(
545545 outtype = _out isa SubArray ?
546546 ArrayInterface. parameterless_type (_out. parent) :
547547 ArrayInterface. parameterless_type (_out)
548+ Δu = Δ isa Tangent ? unthunk .(Δ. u) : Δ
548549 if only_end
549- eltype (Δ ) <: NoTangent && return
550- if (Δ isa AbstractArray{<: AbstractArray } || Δ isa AbstractVectorOfArray) &&
551- length (Δ ) == 1 && i == 1
550+ eltype (Δu ) <: NoTangent && return
551+ if (Δu isa AbstractArray{<: AbstractArray } || Δu isa AbstractVectorOfArray) &&
552+ length (Δu ) == 1 && i == 1
552553 # user did sol[end] on only_end
553- x = Δ isa AbstractVectorOfArray ? Δ . u [1 ] : Δ [1 ]
554+ x = Δu isa AbstractVectorOfArray ? Δu [1 ] : Δu [1 ]
554555 if _save_idxs isa Number
555556 vx = vec (x)
556557 _out[_save_idxs] .= vx[_save_idxs]
@@ -563,22 +564,21 @@ function DiffEqBase._concrete_solve_adjoint(
563564 else
564565 Δ isa NoTangent && return
565566 if _save_idxs isa Number
566- x = vec (Δ )
567+ x = vec (Δu )
567568 _out[_save_idxs] .= adapt (outtype, @view (x[_save_idxs]))
568569 elseif _save_idxs isa Colon
569- vec (_out) .= vec (adapt (outtype, Δ ))
570+ vec (_out) .= vec (adapt (outtype, Δu ))
570571 else
571- x = vec (Δ )
572+ x = vec (Δu )
572573 vec (@view (_out[_save_idxs])) .= adapt (outtype, @view (x[_save_idxs]))
573574 end
574575 end
575576 else
576- Δu = Δ isa Tangent ? Δ. u : Δ
577577 ! Base. isconcretetype (eltype (Δ)) &&
578578 (Δu[i] isa NoTangent || eltype (Δu) <: NoTangent ) && return
579579 if Δ isa AbstractArray{<: AbstractArray } || Δ isa AbstractVectorOfArray ||
580580 Δ isa Tangent
581- x = (Δ isa AbstractVectorOfArray || Δ isa Tangent) ? Δ . u [i] : Δ[i]
581+ x = (Δ isa AbstractVectorOfArray || Δ isa Tangent) ? Δu [i] : Δ[i]
582582 if _save_idxs isa Number
583583 _out[_save_idxs] = x[_save_idxs]
584584 elseif _save_idxs isa Colon
@@ -1017,7 +1017,7 @@ function DiffEqBase._concrete_solve_adjoint(
10171017 ForwardDiff. value .(J' vec (v))
10181018 end
10191019 else
1020- zero (p )
1020+ zero (v )
10211021 end
10221022 end
10231023 push! (pparts, vec (_dp))
@@ -1431,6 +1431,12 @@ function DiffEqBase._concrete_solve_adjoint(
14311431 Array (ybar) # can also be a ODESolution
14321432 elseif eltype (ybar) <: Number # CuArray{Floats}
14331433 ybar
1434+ elseif ybar isa Tangent
1435+ ut = unthunk .(ybar. u)
1436+ ut_ = map (ut) do u
1437+ (u isa ZeroTangent || u isa NoTangent) ? zero (u0) : u
1438+ end
1439+ reduce (hcat, ut_)
14341440 elseif ybar[1 ] isa Array
14351441 return Array (ybar)
14361442 else
@@ -1585,7 +1591,10 @@ function DiffEqBase._concrete_solve_adjoint(
15851591 elseif eltype (ybar) <: AbstractArray
15861592 Array (VectorOfArray (ybar))
15871593 elseif ybar isa Tangent
1588- Array (VectorOfArray (ybar. u))
1594+ yy = map (unthunk .(ybar. u)) do u
1595+ (u isa ZeroTangent || u isa NoTangent) ? zero (u0) : u
1596+ end
1597+ Array (VectorOfArray (yy))
15891598 else
15901599 ybar
15911600 end
0 commit comments