Skip to content

Commit 2b37c4d

Browse files
fix downstream
1 parent 136475e commit 2b37c4d

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

src/solve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ function get_concrete_problem(prob; kwargs...)
132132
u0_promote = promote_u0(u0, p, tspan[1])
133133
tspan_promote = promote_tspan(u0, p, tspan, prob, kwargs)
134134
if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(u0) &&
135-
prob.tspan == tspan && typeof(tspan) === typeof(tspan_promote)
135+
prob.tspan == tspan && typeof(tspan) === typeof(tspan_promote) &&
136+
p == prob.p
136137
return prob
137138
else
138139
return remake(prob; u0 = u0_promote, p=p, tspan = tspan_promote)

src/zygote.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ ZygoteRules.@adjoint numargs(f) = (numargs(f),df->(nothing,))
7575
ChainRulesCore.rrule(::typeof(numargs),f) = (numargs(f),df->(nothing,))
7676

7777
# Until https://github.com/FluxML/Zygote.jl/issues/664 is fixed
78-
ZygoteRules.@adjoint function Base.pairs(x::NamedTuple)
79-
Base.pairs(x), Δ ->.data,)
80-
end
78+
ZygoteRules.@adjoint function Base.pairs(x::T) where T
79+
y = Base.pairs(x)
80+
back(dx::NamedTuple) = (dx.data,)
81+
function back(dx::Dict)
82+
T <: AbstractDict && return (dx,)
83+
z = zero(x)
84+
for (k,v) in dx
85+
z[k] = v
86+
end
87+
(z,)
88+
end
89+
y, back
90+
end

0 commit comments

Comments
 (0)