Skip to content

Commit c4178b1

Browse files
chore: treat prob as immutable
1 parent e3b0108 commit c4178b1

File tree

6 files changed

+24
-13
lines changed

6 files changed

+24
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
3131
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3232
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3333
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
34+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3435

3536
[weakdeps]
3637
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
@@ -41,7 +42,6 @@ PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
4142
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
4243
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
4344
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
44-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4545

4646
[extensions]
4747
SciMLBaseChainRulesCoreExt = "ChainRulesCore"

ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ end
5858

5959
function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...)
6060
function ODEProblemAdjoint(ȳ)
61+
@show "some con"
6162
(NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
6263
end
6364

ext/SciMLBaseZygoteExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ end
300300
∇responsible_map(__context__, f, args...)
301301
end
302302

303-
@_adjoint_keepthunks function Zygote.literal_getfield(x::ODEProblem, ::Val{f}) where f
303+
@_adjoint_keepthunks function Zygote.literal_getfield(x::SciMLBase.AbstractSciMLProblem, ::Val{f}) where f
304304
val = getfield(x, f)
305305
function back(Δ)
306306
Zygote.accum_param(__context__, val, Δ) === nothing && return
@@ -311,7 +311,7 @@ end
311311
else
312312
dx = Zygote.grad_mut(__context__, x)
313313
dx[] = (; dx[]..., pair(Val(f), Zygote.accum(getfield(dx[], f), Δ))...)
314-
return (dx,nothing)
314+
return (dx[],nothing)
315315
end
316316
end
317317
Zygote.unwrap(val), back

src/SciMLBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ import SciMLOperators:
4444

4545
@reexport using SciMLOperators
4646

47+
using Zygote
48+
4749
function __solve end
4850
function __init end
4951

src/initialization.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,8 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
260260
end
261261
nlsol, success = solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg )
262262

263-
nlsol2 = prob.f.initialization_data.initializeprob
264263
if initdata.initializeprobmap !== nothing
265-
u02 = initdata.initializeprobmap(nlsol2)
264+
u02 = initdata.initializeprobmap(nlsol)
266265
end
267266
if initdata.initializeprobpmap !== nothing
268267
p2 = initdata.initializeprobpmap(valp, nlsol)

src/remake.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,8 @@ function remake(prob::NonlinearProblem;
703703
if problem_type === missing
704704
problem_type = prob.problem_type
705705
end
706+
# error()
707+
# @show f
706708

707709
prob = if kwargs === missing
708710
NonlinearProblem{isinplace(prob)}(f = f, u0 = newu0, p = newp,
@@ -1206,18 +1208,25 @@ function maybe_eager_initialize_problem(prob::AbstractSciMLProblem, initializati
12061208
if lazy_initialization === nothing
12071209
lazy_initialization = !is_trivial_initialization(initialization_data)
12081210
end
1209-
if initialization_data !== nothing && !lazy_initialization &&
1210-
(!is_time_dependent(prob) || current_time(prob) !== nothing)
1211+
cond = initialization_data !== nothing && !lazy_initialization &&
1212+
(!is_time_dependent(prob) || current_time(prob) !== nothing)
1213+
@show cond
1214+
if cond
1215+
# @show "in maybe_eager_initialize_problem"
12111216
u0, p, _ = get_initial_values(
12121217
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
1213-
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
1214-
u0 = nothing
1215-
end
1218+
# if u0 !== nothing && eltype(u0) == Any && isempty(u0)
1219+
# u0 = nothing
1220+
# end
12161221
else
1217-
u0 = state_values(prob)
1218-
p = parameter_values(prob)
1222+
u02 = state_values(prob)
1223+
p2 = parameter_values(prob)
12191224
end
1220-
return u0, p
1225+
# @show p
1226+
1227+
u03 = cond ? u0 : u02
1228+
p3 = cond ? p : p2
1229+
return u03, p3
12211230
end
12221231

12231232
function remake(thing::AbstractJumpProblem; kwargs...)

0 commit comments

Comments
 (0)