Skip to content

Commit 172a6bc

Browse files
fix: fix definition of T and N in ChainRulesCore adjoint
1 parent 3845403 commit 172a6bc

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ function ChainRulesCore.rrule(
2525
dp = zero_tangent(parameter_values(VA.prob))
2626
end
2727
dprob = remake(VA.prob, p = dp)
28-
T = eltype(eltype(VA.u))
29-
N = length(VA.prob.p)
3028
du, dprob
3129
else
3230
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
@@ -35,6 +33,8 @@ function ChainRulesCore.rrule(
3533
dprob = remake(VA.prob, p = dp)
3634
du, dprob
3735
end
36+
T = eltype(eltype(du))
37+
N = ndims(eltype(du)) + 1
3838
Δ′ = ODESolution{T, N}(du, nothing, nothing, VA.t, VA.k, nothing, dprob,
3939
VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode)
4040
(NoTangent(), Δ′, NoTangent(), NoTangent())

0 commit comments

Comments
 (0)