Skip to content

Commit 25b56d7

Browse files
committed
working codegen
1 parent 2fcb9c9 commit 25b56d7

File tree

4 files changed

+18
-6
lines changed

4 files changed

+18
-6
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -899,14 +899,14 @@ get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
899899
"""
900900
function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan; kwargs...)
901901
iv = get_iv(sys)
902-
sts = get_unknowns(sys)
903-
ps = get_ps(sys)
902+
sts = unknowns(sys)
903+
ps = parameters(sys)
904904
np = length(ps)
905905
ns = length(sts)
906906
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
907907
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
908908

909-
@variables sol(..)[1:ns] p[1:np]
909+
@variables sol(..)[1:ns]
910910

911911
conssys = get_constraintsystem(sys)
912912
cons = Any[]
@@ -931,7 +931,7 @@ function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan; kwargs...)
931931
exprs = vcat(init_conds, cons)
932932
_p = reorder_parameters(sys, ps)
933933

934-
build_function_wrapper(sys, exprs, sol, _p..., t; kwargs...)
934+
build_function_wrapper(sys, exprs, sol, _p..., t; output_type = Array, kwargs...)
935935
end
936936

937937
"""

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,6 @@ function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; c
687687
constraintsts = OrderedSet()
688688
constraintps = OrderedSet()
689689

690-
# Hack? to extract parameters from callable variables in constraints.
691690
for cons in constraints
692691
collect_vars!(constraintsts, constraintps, cons, iv)
693692
end
@@ -712,5 +711,6 @@ function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; c
712711
end
713712
end
714713

714+
@show constraints
715715
ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); name = consname)
716716
end

src/systems/optimization/constraints_system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ function ConstraintsSystem(constraints, unknowns, ps;
123123
name === nothing &&
124124
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
125125

126-
cstr = value.(Symbolics.canonical_form.(scalarize(constraints)))
126+
cstr = value.(Symbolics.canonical_form.(vcat(scalarize(constraints)...)))
127127
unknowns′ = value.(scalarize(unknowns))
128128
ps′ = value.(ps)
129129

test/odesystem.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1670,4 +1670,16 @@ end
16701670
@test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
16711671
cons = [x(t) * v ~ 3]
16721672
@test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons) # Need time argument.
1673+
1674+
# Test array variables
1675+
@variables x(..)[1:5]
1676+
mat = [1 2 0 3 2
1677+
0 0 3 2 0
1678+
0 1 3 0 4
1679+
2 0 0 2 1
1680+
0 0 2 0 5]
1681+
eqs = D(x(t)) ~ mat * x(t)
1682+
cons = [x(3) ~ [2,3,3,5,4]]
1683+
@mtkbuild ode = ODESystem(D(x(t)) ~ mat * x(t), t; constraints = cons)
1684+
@test length(constraints(ModelingToolkit.get_constraintsystem(ode))) == 5
16731685
end

0 commit comments

Comments
 (0)