Skip to content

Commit 4fe0a3a

Browse files
authored
Merge pull request #191 from JuliaDynamics/hw/initbounds
add variable transformation to respect bounds in initialization problem
2 parents 5c0e97b + ef35144 commit 4fe0a3a

File tree

8 files changed

+434
-69
lines changed

8 files changed

+434
-69
lines changed

docs/src/API.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ set_bounds!
128128
find_fixpoint
129129
initialize_component!
130130
init_residual
131+
get_initial_state
132+
dump_initial_state
131133
```
132134

133135
## Execution Types

docs/src/initialization.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,8 @@ As a quick test we can ensure that the angle indeed matches the voltag angel:
6767
```@example compinit
6868
get_init(vf, :θ) ≈ atan(get_default(vf, :u_i), get_default(vf, :u_r))
6969
```
70+
71+
It is possible to inspect initial states (also for observed symbols) using [`get_initial_state`](@ref). You can print out the whole state using [`dump_initial_state`](@ref).
72+
```@example compinit
73+
dump_initial_state(vf)
74+
```

ext/MTKExt.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,14 +262,13 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
262262
deepcopy(_sys)
263263
else
264264
_openinputs = setdiff(allinputs, Set(full_parameters(_sys)))
265-
get_variables.(full_equations(_sys))
266265
all_eq_vars = mapreduce(get_variables, union, full_equations(_sys), init=Set{Symbolic}())
267266
if !(_openinputs all_eq_vars)
268267
missing_inputs = setdiff(_openinputs, all_eq_vars)
269268
@warn "The specified inputs ($missing_inputs) do not appear in the equations of the system!"
270269
_openinputs = setdiff(_openinputs, missing_inputs)
271270
end
272-
structural_simplify(_sys, (_openinputs, alloutputs); simplify=true)[1]
271+
structural_simplify(_sys, (_openinputs, alloutputs); simplify=false)[1]
273272
end
274273

275274
states = unknowns(sys)
@@ -334,6 +333,10 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
334333
throw(ArgumentError("Output $out was neither foundin states nor in observed equations."))
335334
end
336335
eq = obseqs[idx]
336+
if !isempty(rhs_differentials(eq))
337+
println(obs_subs[out])
338+
throw(ArgumentError("Algebraic FF equation for output $out contains differentials in the RHS: $(rhs_differentials(eq))"))
339+
end
337340
deleteat!(obseqs, idx)
338341

339342
if ff_to_constraint && !isempty(get_variables(eq.rhs) allinputs)
@@ -417,7 +420,7 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
417420
obsf = obsf_ip,
418421
equations=formulas,
419422
outputeqs=Dict(Iterators.flatten(outputss) .=> gformulas),
420-
observed=Dict(obsstates .=> obsformulas),
423+
observed=Dict(getname.(obsstates) .=> obsformulas),
421424
params)
422425
end
423426

src/NetworkDynamics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ export has_guess, get_guess, set_guess!
6969
export has_init, get_init, set_init!
7070
export has_bounds, get_bounds, set_bounds!
7171
export has_graphelement, get_graphelement, set_graphelement!
72+
export get_initial_state, dump_initial_state
7273
include("metadata.jl")
7374

7475
using NonlinearSolve: AbstractNonlinearSolveAlgorithm, NonlinearFunction

src/initialization.jl

Lines changed: 141 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ function _solve_fixpoint(prob, alg::SteadyStateDiffEqAlgorithm; kwargs...)
4848
sol = SciMLBase.solve(prob, alg; kwargs...)
4949
end
5050

51-
function initialization_problem(cf::T; t=NaN, verbose=true) where {T<:ComponentModel}
51+
function initialization_problem(cf::T; t=NaN, apply_bound_transformation, verbose=true) where {T<:ComponentModel}
5252
hasinsym(cf) || throw(ArgumentError("Component model musst have `insym`!"))
5353

5454
outfree_ms = Tuple((!).(map(s -> has_default(cf, s), sv)) for sv in outsym_normalized(cf))
@@ -92,6 +92,48 @@ function initialization_problem(cf::T; t=NaN, verbose=true) where {T<:ComponentM
9292
unl_range_p = nextrange(pfree_m)
9393
@assert vcat(unl_range_outs..., unl_range_u, unl_range_ins..., unl_range_p) == 1:Nfree
9494

95+
# check for positivity and negativity constraints
96+
bounds = map(freesym) do sym
97+
if has_bounds(cf, sym)
98+
bound = get_bounds(cf, sym)
99+
if bound[1] >= 0 && bound[2] > bound[1]
100+
return :pos
101+
elseif bound[1] < bound[2] && bound[2] <= 0
102+
return :neg
103+
end
104+
end
105+
:none
106+
end
107+
if !apply_bound_transformation || all(isequal(:none), bounds)
108+
boundT! = identity
109+
inv_boundT! = identity
110+
else
111+
if verbose
112+
idxs = findall(!isequal(:none), bounds)
113+
@info "Apply positivity/negativity conserving variable transformation on $(freesym[idxs]) to satisfy bounds."
114+
end
115+
boundT! = (u) -> begin
116+
for i in eachindex(u, bounds)
117+
if bounds[i] == :pos
118+
u[i] = u[i]^2
119+
elseif bounds[i] == :neg
120+
u[i] = -u[i]^2
121+
end
122+
end
123+
return u
124+
end
125+
inv_boundT! = (u) -> begin
126+
for i in eachindex(u, bounds)
127+
if bounds[i] == :pos
128+
u[i] = sqrt(u[i])
129+
elseif bounds[i] == :neg
130+
u[i] = sqrt(-u[i])
131+
end
132+
end
133+
return u
134+
end
135+
end
136+
95137
# check for missin guesses
96138
missing_guesses = Symbol[]
97139
uguess = map(freesym) do s
@@ -101,67 +143,68 @@ function initialization_problem(cf::T; t=NaN, verbose=true) where {T<:ComponentM
101143
push!(missing_guesses, s)
102144
end
103145
end
146+
isempty(missing_guesses) || throw(ArgumentError("Missing guesses for free variables $(missing_guesses)"))
104147

105-
for s in freesym
106-
if has_bounds(cf, s)
107-
@warn "Ignore bounds $(get_bounds(cf, s)) for $s. Not supported yet."
108-
end
148+
# apply bound conserving transformation to initial state
149+
try
150+
inv_boundT!(uguess)
151+
catch
152+
throw(ArgumentError("Initial guess violates bounds. Check the docstring on `NetworkDynamics.initialize_component!`\
153+
about bound satisfying transformations!"))
109154
end
110155

111-
isempty(missing_guesses) || throw(ArgumentError("Missing guesses for free variables $(missing_guesses)"))
112-
113156
N = ForwardDiff.pickchunksize(Nfree)
114-
fz = let fg = compfg(cf),
115-
outcaches=map(d->DiffCache(zeros(d), N), outdim_normalized(cf)),
116-
ucache=DiffCache(zeros(dim(cf)), N),
117-
incaches=map(d->DiffCache(zeros(d), N), indim_normalized(cf)),
118-
pcache=DiffCache(zeros(pdim(cf))),
119-
t=t,
120-
outfree_ms=outfree_ms, ufree_m=ufree_m, infree_ms=infree_ms, pfree_m=pfree_m,
121-
outfixs=outfixs, ufix=ufix, infixs=infixs, pfix=pfix,
122-
unl_range_outs=unl_range_outs, unl_range_u=unl_range_u,
123-
unl_range_ins=unl_range_ins, unl_range_p=unl_range_p
124-
125-
(dunl, unl, _) -> begin
126-
outbufs = PreallocationTools.get_tmp.(outcaches, Ref(dunl))
127-
ubuf = PreallocationTools.get_tmp(ucache, dunl)
128-
inbufs = PreallocationTools.get_tmp.(incaches, Ref(dunl))
129-
pbuf = PreallocationTools.get_tmp(pcache, dunl)
130-
131-
# prefill buffers with fixed values
132-
for (buf, fix) in zip(outbufs, outfixs)
133-
buf .= fix
134-
end
135-
ubuf .= ufix
136-
for (buf, fix) in zip(inbufs, infixs)
137-
buf .= fix
138-
end
139-
pbuf .= pfix
157+
fg = compfg(cf)
158+
unlcache = map(d->DiffCache(zeros(d), N), length(freesym))
159+
outcaches = map(d->DiffCache(zeros(d), N), outdim_normalized(cf))
160+
ucache = DiffCache(zeros(dim(cf)), N)
161+
incaches = map(d->DiffCache(zeros(d), N), indim_normalized(cf))
162+
pcache = DiffCache(zeros(pdim(cf)))
163+
164+
fz = (dunl, unl, _) -> begin
165+
# apply the bound conserving transformation
166+
unlbuf = PreallocationTools.get_tmp(unlcache, unl)
167+
unlbuf .= unl
168+
boundT!(unlbuf)
169+
170+
outbufs = PreallocationTools.get_tmp.(outcaches, Ref(dunl))
171+
ubuf = PreallocationTools.get_tmp(ucache, dunl)
172+
inbufs = PreallocationTools.get_tmp.(incaches, Ref(dunl))
173+
pbuf = PreallocationTools.get_tmp(pcache, dunl)
174+
175+
# prefill buffers with fixed values
176+
for (buf, fix) in zip(outbufs, outfixs)
177+
buf .= fix
178+
end
179+
ubuf .= ufix
180+
for (buf, fix) in zip(inbufs, infixs)
181+
buf .= fix
182+
end
183+
pbuf .= pfix
140184

141-
# overwrite nonfixed values
142-
for (buf, mask, range) in zip(outbufs, outfree_ms, unl_range_outs)
143-
buf[mask] .= unl[range]
144-
end
145-
ubuf[ufree_m] .= unl[unl_range_u]
146-
for (buf, mask, range) in zip(inbufs, infree_ms, unl_range_ins)
147-
buf[mask] .= unl[range]
148-
end
149-
pbuf[pfree_m] .= unl[unl_range_p]
185+
# overwrite nonfixed values
186+
for (buf, mask, range) in zip(outbufs, outfree_ms, unl_range_outs)
187+
_overwrite_at_mask!(buf, mask, unlbuf, range)
188+
end
189+
_overwrite_at_mask!(ubuf, ufree_m, unlbuf, unl_range_u)
190+
for (buf, mask, range) in zip(inbufs, infree_ms, unl_range_ins)
191+
_overwrite_at_mask!(buf, mask, unlbuf, range)
192+
end
193+
_overwrite_at_mask!(pbuf, pfree_m, unlbuf, unl_range_p)
150194

151-
# view into du buffer for the fg funtion
152-
@views dunl_fg = dunl[1:dim(cf)]
153-
# view into the output buffer for the outputs
154-
@views dunl_out = dunl[dim(cf)+1:end]
195+
# view into du buffer for the fg funtion
196+
@views dunl_fg = dunl[1:dim(cf)]
197+
# view into the output buffer for the outputs
198+
@views dunl_out = dunl[dim(cf)+1:end]
155199

156-
# this fills the second half of the du buffer with the fixed and current outputs
157-
dunl_out .= RecursiveArrayTools.ArrayPartition(outbufs...)
158-
# execute fg to fill dunl and outputs
159-
fg(outbufs, dunl_fg, ubuf, inbufs, pbuf, t)
200+
# this fills the second half of the du buffer with the fixed and current outputs
201+
dunl_out .= RecursiveArrayTools.ArrayPartition(outbufs...)
202+
# execute fg to fill dunl and outputs
203+
fg(outbufs, dunl_fg, ubuf, inbufs, pbuf, t)
160204

161-
# calculate the residual for the second half ov the dunl buf, the outputs
162-
dunl_out .= dunl_out .- RecursiveArrayTools.ArrayPartition(outbufs...)
163-
nothing
164-
end
205+
# calculate the residual for the second half ov the dunl buf, the outputs
206+
dunl_out .= dunl_out .- RecursiveArrayTools.ArrayPartition(outbufs...)
207+
nothing
165208
end
166209

167210
nlf = NonlinearFunction(fz; resid_prototype=zeros(Neqs), sys=SII.SymbolCache(freesym))
@@ -171,11 +214,21 @@ function initialization_problem(cf::T; t=NaN, verbose=true) where {T<:ComponentM
171214
else
172215
verbose && @info "Initialization problem is overconstrained ($Nfree vars for $Neqs equations). Create NonlinearLeastSquaresProblem for $freesym."
173216
end
174-
NonlinearLeastSquaresProblem(nlf, uguess)
217+
(NonlinearLeastSquaresProblem(nlf, uguess), boundT!)
218+
end
219+
function _overwrite_at_mask!(target, mask, source, range)
220+
src_v = view(source, range)
221+
j = 1
222+
for i in eachindex(target, mask)
223+
if mask[i]
224+
target[i] = src_v[j]
225+
j += 1
226+
end
227+
end
175228
end
176229

177230
"""
178-
initialize_component!(cf::ComponentModel; verbose=true, kwargs...)
231+
initialize_component!(cf::ComponentModel; verbose=true, apply_bound_transformation=true, kwargs...)
179232
180233
Initialize a `ComponentModel` by solving the corresponding `NonlinearLeastSquaresProblem`.
181234
During initialization, everyting which has a `default` value (see [Metadata](@ref)) is considered
@@ -186,9 +239,16 @@ The result is stored in the `ComponentModel` itself. The values of the free vari
186239
in the metadata field `init`.
187240
188241
The `kwargs` are passed to the nonlinear solver.
242+
243+
## Bounds of free variables
244+
When encountering any bounds in the free variables, NetworkDynamics will try to conserve them
245+
by applying a coordinate transforamtion. This behavior can be supressed by setting `apply_bound_transformation`.
246+
The following transformations are used:
247+
- (a, b) intervals where both a and b are positive are transformed to `u^2`/`sqrt(u)`
248+
- (a, b) intervals where both a and b are negative are transformed to `-u^2`/`sqrt(-u)`
189249
"""
190-
function initialize_component!(cf; verbose=true, kwargs...)
191-
prob = initialization_problem(cf; verbose)
250+
function initialize_component!(cf; verbose=true, apply_bound_transformation=true, kwargs...)
251+
prob, boundT! = initialization_problem(cf; verbose, apply_bound_transformation)
192252

193253
if !isempty(prob.u0)
194254
sol = SciMLBase.solve(prob; kwargs...)
@@ -202,14 +262,22 @@ function initialize_component!(cf; verbose=true, kwargs...)
202262
else
203263
verbose && @info "Initialization successful with residual $(LinearAlgebra.norm(sol.resid))"
204264
end
205-
set_init!.(Ref(cf), SII.variable_symbols(sol), sol.u)
265+
# transform back to original space
266+
u = boundT!(copy(sol.u))
267+
set_init!.(Ref(cf), SII.variable_symbols(sol), u)
206268
resid = sol.resid
207269
else
208270
resid = init_residual(cf; recalc=true)
209271
verbose && @info "No free variables! Residual $(LinearAlgebra.norm(resid))"
210272
end
211273

212274
set_metadata!(cf, :init_residual, resid)
275+
276+
broken = broken_bounds(cf)
277+
if !isempty(broken)
278+
@warn "Initialized model has broken bounds for $(broken). Use `dump_initial_state(mode)` \
279+
to inspect further and try to adapt the initial guesses!"
280+
end
213281
cf
214282
end
215283

@@ -250,3 +318,17 @@ function init_residual(cf::T; t=NaN, recalc=false) where {T<:ComponentModel}
250318
LinearAlgebra.norm(get_metadata(cf, :init_residual))
251319
end
252320
end
321+
322+
function broken_bounds(cf)
323+
allsyms = vcat(sym(cf), psym(cf), insym_all(cf), outsym_flat(cf), obssym(cf))
324+
boundsyms = filter(s -> has_bounds(cf, s), allsyms)
325+
bounds = get_bounds.(Ref(cf), boundsyms)
326+
vals = get_initial_state(cf, boundsyms)
327+
broken = filter(i -> !bounds_satisfied(vals[i], bounds[i]), 1:length(bounds))
328+
boundsyms[broken]
329+
end
330+
331+
function bounds_satisfied(val, bounds)
332+
@assert length(bounds) == 2
333+
!isnothing(val) && !isnan(val) && first(bounds) val last(bounds)
334+
end

0 commit comments

Comments
 (0)