From 2e836ce2d466bff375970dcca08a5d0e91d95114 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 8 May 2025 19:53:50 +0530 Subject: [PATCH 1/9] chore: rm literal_getproperty(sol, ::Val{:u}) --- src/adjoint_common.jl | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 609d26528..b8c0e5ccb 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -749,16 +749,3 @@ function out_and_ts(_ts, duplicate_iterator_times, sol) return out, ts end -if !hasmethod(Zygote.adjoint, - Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty), - SciMLBase.AbstractTimeseriesSolution, Val{:u}}) - Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution, - ::Val{:u}) - function solu_adjoint(Δ) - zerou = zero(sol.prob.u0) - _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) - (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) - end - sol.u, solu_adjoint - end -end \ No newline at end of file From 2d50ea0215d4139448cc4612627dbd2f3d436ef8 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 9 May 2025 15:04:15 +0530 Subject: [PATCH 2/9] chore: unthunk tangents --- src/concrete_solve.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 2bc4d65be..30a16c8cd 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1585,7 +1585,8 @@ function DiffEqBase._concrete_solve_adjoint( elseif eltype(ybar) <: AbstractArray Array(VectorOfArray(ybar)) elseif ybar isa Tangent - Array(VectorOfArray(ybar.u)) + yy = unthunk(ybar) + Array(VectorOfArray(unthunk.(unthunk(yy.u)))) else ybar end From 054fecf49155611de64b2bcd4fa070ef19344afa Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 9 May 2025 15:04:56 +0530 Subject: [PATCH 3/9] chore: unthunk tangents once --- src/concrete_solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 30a16c8cd..9c3577e5a 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1586,7 +1586,7 @@ function DiffEqBase._concrete_solve_adjoint( Array(VectorOfArray(ybar)) elseif ybar isa Tangent yy = unthunk(ybar) - Array(VectorOfArray(unthunk.(unthunk(yy.u)))) + Array(VectorOfArray(unthunk.(yy.u))) else ybar end From b0649f5a7fb45e64e5734836327797875592b524 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 9 May 2025 16:58:07 +0530 Subject: [PATCH 4/9] chore: set tangent to .u --- src/concrete_solve.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 9c3577e5a..58b5aac8f 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -545,12 +545,13 @@ function DiffEqBase._concrete_solve_adjoint( outtype = _out isa SubArray ? ArrayInterface.parameterless_type(_out.parent) : ArrayInterface.parameterless_type(_out) + Δu = Δ isa Tangent ? unthunk.(Δ.u) : Δ if only_end - eltype(Δ) <: NoTangent && return - if (Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray) && - length(Δ) == 1 && i == 1 + eltype(Δu) <: NoTangent && return + if (Δu isa AbstractArray{<:AbstractArray} || Δu isa AbstractVectorOfArray) && + length(Δu) == 1 && i == 1 # user did sol[end] on only_end - x = Δ isa AbstractVectorOfArray ? Δ.u[1] : Δ[1] + x = Δu isa AbstractVectorOfArray ? Δu[1] : Δu[1] if _save_idxs isa Number vx = vec(x) _out[_save_idxs] .= vx[_save_idxs] @@ -563,12 +564,12 @@ function DiffEqBase._concrete_solve_adjoint( else Δ isa NoTangent && return if _save_idxs isa Number - x = vec(Δ) + x = vec(Δu) _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) elseif _save_idxs isa Colon - vec(_out) .= vec(adapt(outtype, Δ)) + vec(_out) .= vec(adapt(outtype, Δu)) else - x = vec(Δ) + x = vec(Δu) vec(@view(_out[_save_idxs])) .= adapt(outtype, @view(x[_save_idxs])) end end From cd0aded9e7d0d7f82cf7ce33a46623bea5795cfe Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 9 May 2025 17:02:45 +0530 Subject: [PATCH 5/9] chore: remove extra call to unthunk --- src/concrete_solve.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 58b5aac8f..0d0a29036 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -574,12 +574,11 @@ function DiffEqBase._concrete_solve_adjoint( end end else - Δu = Δ isa Tangent ? Δ.u : Δ !Base.isconcretetype(eltype(Δ)) && (Δu[i] isa NoTangent || eltype(Δu) <: NoTangent) && return if Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray || Δ isa Tangent - x = (Δ isa AbstractVectorOfArray || Δ isa Tangent) ? Δ.u[i] : Δ[i] + x = (Δ isa AbstractVectorOfArray || Δ isa Tangent) ? Δu[i] : Δ[i] if _save_idxs isa Number _out[_save_idxs] = x[_save_idxs] elseif _save_idxs isa Colon From 9640bcf2880cd9c06b82316f4233a32c67155470 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 9 May 2025 17:10:27 +0530 Subject: [PATCH 6/9] chore: replace p with tunables in fallback for FwdSens --- src/concrete_solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 0d0a29036..a3c9aafaa 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1017,7 +1017,7 @@ function DiffEqBase._concrete_solve_adjoint( ForwardDiff.value.(J'vec(v)) end else - zero(p) + zero(tunables) end end push!(pparts, vec(_dp)) From b031627358f65abafe22ffb5cf2956ff6a6c961c Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 12 May 2025 13:59:30 +0530 Subject: [PATCH 7/9] chore: tunables -> v in fallback for dp --- src/concrete_solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index a3c9aafaa..b624443cd 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1017,7 +1017,7 @@ function DiffEqBase._concrete_solve_adjoint( ForwardDiff.value.(J'vec(v)) end else - zero(tunables) + zero(v) end end push!(pparts, vec(_dp)) From 1c4b37c119eb00b4dd2c7b7f721c13f2bda44a20 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 12 May 2025 19:48:14 +0530 Subject: [PATCH 8/9] chore: allocate zero tangents to zeros --- src/concrete_solve.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index b624443cd..5af7bc0e8 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1431,6 +1431,12 @@ function DiffEqBase._concrete_solve_adjoint( Array(ybar) # can also be a ODESolution elseif eltype(ybar) <: Number # CuArray{Floats} ybar + elseif ybar isa Tangent + ut = unthunk.(ybar.u) + ut_ = map(ut) do u + (u isa ZeroTangent || u isa NoTangent) ? zero(u0) : u + end + reduce(hcat, ut_) elseif ybar[1] isa Array return Array(ybar) else @@ -1586,7 +1592,10 @@ function DiffEqBase._concrete_solve_adjoint( Array(VectorOfArray(ybar)) elseif ybar isa Tangent yy = unthunk(ybar) - Array(VectorOfArray(unthunk.(yy.u))) + yy = map(unthunk.(yy.u)) do u + (u isa ZeroTangent || u isa NoTangent) ? zero(u0) : u + end + Array(VectorOfArray(yy)) else ybar end From fb3a20b62e4f599e863651416f7cbedf9dad4063 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 13 May 2025 16:47:07 +0530 Subject: [PATCH 9/9] chore: unthunk only u --- src/concrete_solve.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 5af7bc0e8..7ec4e8b07 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1591,8 +1591,7 @@ function DiffEqBase._concrete_solve_adjoint( elseif eltype(ybar) <: AbstractArray Array(VectorOfArray(ybar)) elseif ybar isa Tangent - yy = unthunk(ybar) - yy = map(unthunk.(yy.u)) do u + yy = map(unthunk.(ybar.u)) do u (u isa ZeroTangent || u isa NoTangent) ? zero(u0) : u end Array(VectorOfArray(yy))