Skip to content

Commit 6c3a328

Browse files
Merge pull request #1196 from SciML/dg/rm_getp
Remove another `literal_getproperty(sol, ::Val{:u})` dispatch
2 parents 9dfbfdc + fb3a20b commit 6c3a328

File tree

2 files changed

+20
-24
lines changed

2 files changed

+20
-24
lines changed

src/adjoint_common.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -748,16 +748,3 @@ function out_and_ts(_ts, duplicate_iterator_times, sol)
748748
return out, ts
749749
end
750750

751-
if !hasmethod(Zygote.adjoint,
752-
Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty),
753-
SciMLBase.AbstractTimeseriesSolution, Val{:u}})
754-
Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution,
755-
::Val{:u})
756-
function solu_adjoint(Δ)
757-
zerou = zero(sol.prob.u0)
758-
= @. ifelse=== nothing, (zerou,), Δ)
759-
(SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
760-
end
761-
sol.u, solu_adjoint
762-
end
763-
end

src/concrete_solve.jl

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,13 @@ function DiffEqBase._concrete_solve_adjoint(
545545
outtype = _out isa SubArray ?
546546
ArrayInterface.parameterless_type(_out.parent) :
547547
ArrayInterface.parameterless_type(_out)
548+
Δu = Δ isa Tangent ? unthunk.(Δ.u) : Δ
548549
if only_end
549-
eltype(Δ) <: NoTangent && return
550-
if (Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray) &&
551-
length(Δ) == 1 && i == 1
550+
eltype(Δu) <: NoTangent && return
551+
if (Δu isa AbstractArray{<:AbstractArray} || Δu isa AbstractVectorOfArray) &&
552+
length(Δu) == 1 && i == 1
552553
# user did sol[end] on only_end
553-
x = Δ isa AbstractVectorOfArray ? Δ.u[1] : Δ[1]
554+
x = Δu isa AbstractVectorOfArray ? Δu[1] : Δu[1]
554555
if _save_idxs isa Number
555556
vx = vec(x)
556557
_out[_save_idxs] .= vx[_save_idxs]
@@ -563,22 +564,21 @@ function DiffEqBase._concrete_solve_adjoint(
563564
else
564565
Δ isa NoTangent && return
565566
if _save_idxs isa Number
566-
x = vec(Δ)
567+
x = vec(Δu)
567568
_out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs]))
568569
elseif _save_idxs isa Colon
569-
vec(_out) .= vec(adapt(outtype, Δ))
570+
vec(_out) .= vec(adapt(outtype, Δu))
570571
else
571-
x = vec(Δ)
572+
x = vec(Δu)
572573
vec(@view(_out[_save_idxs])) .= adapt(outtype, @view(x[_save_idxs]))
573574
end
574575
end
575576
else
576-
Δu = Δ isa Tangent ? Δ.u : Δ
577577
!Base.isconcretetype(eltype(Δ)) &&
578578
(Δu[i] isa NoTangent || eltype(Δu) <: NoTangent) && return
579579
if Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray ||
580580
Δ isa Tangent
581-
x =isa AbstractVectorOfArray || Δ isa Tangent) ? Δ.u[i] : Δ[i]
581+
x =isa AbstractVectorOfArray || Δ isa Tangent) ? Δu[i] : Δ[i]
582582
if _save_idxs isa Number
583583
_out[_save_idxs] = x[_save_idxs]
584584
elseif _save_idxs isa Colon
@@ -1017,7 +1017,7 @@ function DiffEqBase._concrete_solve_adjoint(
10171017
ForwardDiff.value.(J'vec(v))
10181018
end
10191019
else
1020-
zero(p)
1020+
zero(v)
10211021
end
10221022
end
10231023
push!(pparts, vec(_dp))
@@ -1431,6 +1431,12 @@ function DiffEqBase._concrete_solve_adjoint(
14311431
Array(ybar) # can also be a ODESolution
14321432
elseif eltype(ybar) <: Number # CuArray{Floats}
14331433
ybar
1434+
elseif ybar isa Tangent
1435+
ut = unthunk.(ybar.u)
1436+
ut_ = map(ut) do u
1437+
(u isa ZeroTangent || u isa NoTangent) ? zero(u0) : u
1438+
end
1439+
reduce(hcat, ut_)
14341440
elseif ybar[1] isa Array
14351441
return Array(ybar)
14361442
else
@@ -1585,7 +1591,10 @@ function DiffEqBase._concrete_solve_adjoint(
15851591
elseif eltype(ybar) <: AbstractArray
15861592
Array(VectorOfArray(ybar))
15871593
elseif ybar isa Tangent
1588-
Array(VectorOfArray(ybar.u))
1594+
yy = map(unthunk.(ybar.u)) do u
1595+
(u isa ZeroTangent || u isa NoTangent) ? zero(u0) : u
1596+
end
1597+
Array(VectorOfArray(yy))
15891598
else
15901599
ybar
15911600
end

0 commit comments

Comments
 (0)