Skip to content

Commit 9789034

Browse files
chore: revert bad commit
1 parent 6e549e7 commit 9789034

File tree

5 files changed

+24
-44
lines changed

5 files changed

+24
-44
lines changed

src/adjoint_common.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -749,16 +749,16 @@ function out_and_ts(_ts, duplicate_iterator_times, sol)
749749
return out, ts
750750
end
751751

752-
# if !hasmethod(Zygote.adjoint,
753-
# Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty),
754-
# SciMLBase.AbstractTimeseriesSolution, Val{:u}})
755-
# Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution,
756-
# ::Val{:u})
757-
# function solu_adjoint(Δ)
758-
# zerou = zero(sol.prob.u0)
759-
# _Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
760-
# (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
761-
# end
762-
# sol.u, solu_adjoint
763-
# end
764-
# end
752+
if !hasmethod(Zygote.adjoint,
753+
Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty),
754+
SciMLBase.AbstractTimeseriesSolution, Val{:u}})
755+
Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution,
756+
::Val{:u})
757+
function solu_adjoint(Δ)
758+
zerou = zero(sol.prob.u0)
759+
= @. ifelse=== nothing, (zerou,), Δ)
760+
(SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
761+
end
762+
sol.u, solu_adjoint
763+
end
764+
end

src/gauss_adjoint.jl

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -488,19 +488,8 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
488488
ReverseDiff.reverse_pass!(tape)
489489
copyto!(vec(out), ReverseDiff.deriv(tp))
490490
elseif sensealg.autojacvec isa ZygoteVJP
491-
# global gf = f
492-
# global gy = y
493-
# global gtunables = tunables
494-
# global grepack = repack
495-
# global gt = t
496-
# @show f(y, tunables, t)
497-
c = Zygote.bufferfrom(y)
498-
c = copy(y)
499-
# @show f(c, y, repack(tunables), t)
500491
_dy, back = Zygote.pullback(tunables) do tunables
501-
c = Zygote.bufferfrom(y)
502-
f(c, y, repack(tunables), t)
503-
vec(copy(c))
492+
vec(f(y, repack(tunables), t))
504493
end
505494
tmp = back(λ)
506495
if tmp[1] === nothing

src/quadrature_adjoint.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
176176
adj_prob = adj_sol.prob
177177
(; f, tspan) = prob
178178
p = parameter_values(prob)
179-
tunables, _, _ = canonicalize(Tunable(), p)
180179
u0 = state_values(prob)
181180
numparams = length(p)
182181
y = zero(state_values(prob))
@@ -253,9 +252,6 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
253252
end
254253
AdjointSensitivityIntegrand(sol, adj_sol, p, y, λ, pf, f_cache, pJ, paramjac_config,
255254
sensealg, dgdp_cache, dgdp)
256-
257-
# AdjointSensitivityIntegrand(sol, adj_sol, tunables, y, λ, pf, f_cache, pJ, paramjac_config,
258-
# sensealg, dgdp_cache, dgdp)
259255
end
260256

261257
# out = λ df(u, p, t)/dp at u=y, p=p, t=t
@@ -289,15 +285,13 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
289285
copyto!(vec(out), ReverseDiff.deriv(tp))
290286
elseif sensealg.autojacvec isa ZygoteVJP
291287
_dy, back = Zygote.pullback(p) do p
292-
# @show f(y, p, t)
293288
vec(f(y, p, t))
294289
end
295290
tmp = back(λ)
296-
# @show tmp
297291
if tmp[1] === nothing
298292
out[:] .= 0
299293
else
300-
out[:] .= vec(tmp[1].tunable)
294+
out[:] .= vec(tmp[1])
301295
end
302296
elseif sensealg.autojacvec isa MooncakeVJP
303297
_, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ)
@@ -336,7 +330,6 @@ function (S::AdjointSensitivityIntegrand)(out, t)
336330
end
337331

338332
function (S::AdjointSensitivityIntegrand)(t)
339-
# out = similar(S.p.tunable)
340333
out = similar(S.p)
341334
out .= false
342335
S(out, t)
@@ -366,9 +359,7 @@ function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothi
366359
res, err = quadgk(integrand, sol.prob.tspan[1], sol.prob.tspan[2],
367360
atol = abstol, rtol = reltol)
368361
else
369-
tunables, _, _ = canonicalize(Tunable(), integrand.p)
370-
# res = zero(integrand.p)'
371-
res = zero(tunables)'
362+
res = zero(integrand.p)'
372363

373364
# handle discrete dgdp contributions
374365
if dgdp_discrete !== nothing
@@ -530,4 +521,4 @@ function _update_integrand_and_dgrad(res, sensealg::QuadratureAdjoint, cb, integ
530521
vecjacobian!(dλ, integrand.y, dλ, integrand.p, t, fakeS; dgrad = dgrad)
531522
res .-= dgrad
532523
return integrand
533-
end
524+
end

test/parameter_handling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ x = ones(Float32, 2, 3)
1212

1313
nlprob(u, p) = first(model_nls(u, p, st_nls)) .- u
1414

15-
prob = NonlinearProblem(nlprob, zeros(2, 3), ps)
15+
prob = NonlinearProblem(nlprob, zeros(2, 3), ca)
1616

1717
@test_nowarn solve(prob, NewtonRaphson())
1818

test/runtests.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,13 @@ end
8181

8282
if GROUP == "All" || GROUP == "Core6"
8383
@testset "Core 6" begin
84-
# @time @safetestset "Enzyme Closures" include("enzyme_closure.jl")
85-
# @time @safetestset "Complex Matrix FiniteDiff Adjoint" include("complex_matrix_finitediff.jl")
86-
# @time @safetestset "Null Parameters" include("null_parameters.jl")
87-
# @time @safetestset "Forward Mode Prob Kwargs" include("forward_prob_kwargs.jl")
84+
@time @safetestset "Enzyme Closures" include("enzyme_closure.jl")
85+
@time @safetestset "Complex Matrix FiniteDiff Adjoint" include("complex_matrix_finitediff.jl")
86+
@time @safetestset "Null Parameters" include("null_parameters.jl")
87+
@time @safetestset "Forward Mode Prob Kwargs" include("forward_prob_kwargs.jl")
8888
@time @safetestset "Steady State Adjoint" include("steady_state.jl")
89-
# @time @safetestset "Concrete Solve Derivatives of Second Order ODEs" include("second_order_odes.jl")
90-
# @time @safetestset "Parameter Compatibility Errors" include("parameter_compatibility_errors.jl")
89+
@time @safetestset "Concrete Solve Derivatives of Second Order ODEs" include("second_order_odes.jl")
90+
@time @safetestset "Parameter Compatibility Errors" include("parameter_compatibility_errors.jl")
9191
end
9292
end
9393

0 commit comments

Comments
 (0)