Skip to content

Commit c3f3191

Browse files
committed
Finalize ODE codegen
1 parent ff71be5 commit c3f3191

File tree

6 files changed

+53
-53
lines changed

6 files changed

+53
-53
lines changed

src/DAECompiler.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ module DAECompiler
1010
using Core.IR
1111
using SciMLBase
1212
using AutoHashEquals
13+
using LinearAlgebra: LinearAlgebra
1314

1415
include("utils.jl")
1516
include("intrinsics.jl")

src/problem_interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ function ODECProblem(f, tspan::Tuple{Real, Real} = (0., 1.); guesses = nothing,
6666
end
6767

6868
function DiffEqBase.get_concrete_problem(prob::ODECProblem, isadaptive; kwargs...)
69-
odef = factory(Val(Settings(mode=prob.init === nothing ? ODE : ODENoInit)), prob.f)
69+
(odef, n) = factory(Val(Settings(mode=prob.init === nothing ? ODE : ODENoInit)), prob.f)
7070

71-
u0 = zeros(length(prob.init))
71+
u0 = zeros(n)
7272

7373
if prob.init !== nothing
7474
for (which, val) in prob.init

src/transform/codegen/ode_factory.jl

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,35 @@ Given an IR value `arg` that corresponds to `u` in SciML's ODE ABI, split it int
55
the DAECompiler internal ABI.
66
"""
77
function sciml_ode_split_u!(compact, line, arg, numstates)
8-
nassgn = numstates[AssignedDiff]
98
ntotalstates = numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic] + numstates[AlgebraicDerivative]
109

11-
u_mm = @insert_node_here compact line view(arg, 1:nassgn)::VectorViewType
12-
u_unassgn = @insert_node_here compact line view(arg, (nassgn+1):(nassgn+numstates[UnassignedDiff]))::VectorViewType
13-
alg = @insert_node_here compact line view(arg, (nassgn+numstates[UnassignedDiff]+1):ntotalstates)::VectorViewType
10+
u_mm = @insert_node_here compact line view(arg,
11+
1:numstates[AssignedDiff])::VectorViewType
12+
u_unassgn = @insert_node_here compact line view(arg,
13+
(numstates[AssignedDiff] + 1):(numstates[AssignedDiff] + numstates[UnassignedDiff]))::VectorViewType
14+
alg = @insert_node_here compact line view(arg,
15+
(numstates[AssignedDiff] + numstates[UnassignedDiff] + 1):(numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic]))::VectorViewType
16+
alg_derv = @insert_node_here compact line view(arg,
17+
(numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic] + 1):ntotalstates)::VectorViewType
18+
19+
return (u_mm, u_unassgn, alg, alg_derv)
20+
end
1421

15-
return (u_mm, u_unassgn, alg)
22+
function generate_ode_mass_matrix(nd, na)
23+
n = nd + na
24+
mass_matrix = zeros(Float64, n, n)
25+
for i in 1:nd
26+
mass_matrix[i, i] = 1.0
27+
end
28+
return mass_matrix
1629
end
1730

18-
function make_odefunction(f)
19-
ODEFunction(f)
31+
function make_odefunction(f, mass_matrix = LinearAlgebra.I, initf = nothing)
32+
ODEFunction(f; mass_matrix, initialization_data = (initf === nothing ? nothing : initialization_data_ode(initf)))
2033
end
2134

22-
function make_odefunction(f, initf)
23-
ODEFunction(f; initialization_data = SciMLBase.OverrideInitData(NonlinearProblem((args...)->nothing, nothing, nothing), nothing, initf, nothing, nothing))
35+
function initialization_data_ode(initf)
36+
return SciMLBase.OverrideInitData(NonlinearProblem((args...)->nothing, nothing, nothing), nothing, initf, nothing, nothing)
2437
end
2538

2639
"""
@@ -39,8 +52,8 @@ end
3952
```
4053
4154
"""
42-
function ode_factory_gen(result::DAEIPOResult, ci::CodeInstance, key::TornCacheKey, world::UInt, init_key::Union{TornCacheKey, Nothing})
43-
@ccall jl_safe_printf("$key\n"::Cstring)::Cvoid
55+
function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, init_key::Union{TornCacheKey, Nothing})
56+
result = state.result
4457
torn_ci = find_matching_ci(ci->isa(ci.owner, TornIRSpec) && ci.owner.key == key, ci.def, world)
4558
torn_ir = torn_ci.inferred
4659

@@ -65,7 +78,7 @@ function ode_factory_gen(result::DAEIPOResult, ci::CodeInstance, key::TornCacheK
6578
sicm = ()
6679
end
6780

68-
odef_ci = rhs_finish!(result, ci, key, world, 1)
81+
odef_ci = rhs_finish!(state, ci, key, world, 1)
6982

7083
# Create a small opaque closure to adapt from SciML ABI to our own internal
7184
# ABI
@@ -76,7 +89,6 @@ function ode_factory_gen(result::DAEIPOResult, ci::CodeInstance, key::TornCacheK
7689
for var = 1:length(result.var_to_diff)
7790
kind = classify_var(result.var_to_diff, key, var)
7891
kind == nothing && continue
79-
@ccall jl_safe_printf("$kind\n"::Cstring)::Cvoid
8092
numstates[kind] += 1
8193
(kind != AlgebraicDerivative) && push!(all_states, var)
8294
end
@@ -101,17 +113,17 @@ function ode_factory_gen(result::DAEIPOResult, ci::CodeInstance, key::TornCacheK
101113

102114
# out_du_mm, out_eq, in_u_mm, in_u_unassgn, in_alg
103115
nassgn = numstates[AssignedDiff]
104-
ntotalstates = numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic]
116+
ntotalstates = numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic] + numstates[AlgebraicDerivative]
105117
out_du_mm = @insert_node_here oc_compact line view(du, 1:nassgn)::VectorViewType
106118
out_eq = @insert_node_here oc_compact line view(du, (nassgn+1):ntotalstates)::VectorViewType
107119

108-
(in_u_mm, in_u_unassgn, in_alg) = sciml_ode_split_u!(oc_compact, line, u, numstates)
120+
(in_u_mm, in_u_unassgn, in_alg, in_alg_derv) = sciml_ode_split_u!(oc_compact, line, u, numstates)
109121

110122
# Call DAECompiler-generated RHS with internal ABI
111123
oc_sicm = @insert_node_here oc_compact line getfield(self, 1)::Tuple
112124

113125
# N.B: The ordering of arguments should match the ordering in the StateKind enum
114-
@insert_node_here oc_compact line (:invoke)(odef_ci, oc_sicm, (), in_u_mm, in_u_unassgn, in_alg, out_du_mm, out_eq, t)::Nothing
126+
@insert_node_here oc_compact line (:invoke)(odef_ci, oc_sicm, (), in_u_mm, in_u_unassgn, in_alg, in_alg_derv, out_du_mm, out_eq, t)::Nothing
115127

116128
# Return
117129
@insert_node_here oc_compact line (return)::Union{}
@@ -129,14 +141,14 @@ function ode_factory_gen(result::DAEIPOResult, ci::CodeInstance, key::TornCacheK
129141

130142
new_oc = @insert_node_here compact line (:new_opaque_closure)(argt, Union{}, Nothing, true, oc_source_method, sicm)::Core.OpaqueClosure true
131143

132-
if init_key !== nothing
133-
initf = init_uncompress_gen!(compact, result, ci, init_key, key, world)
134-
odef = @insert_node_here compact line make_odefunction(new_oc, initf)::ODEFunction true
135-
else
136-
odef = @insert_node_here compact line make_odefunction(new_oc)::ODEFunction true
137-
end
144+
nd = numstates[AssignedDiff] + numstates[UnassignedDiff]
145+
na = numstates[Algebraic] + numstates[AlgebraicDerivative]
146+
mass_matrix = na == 0 ? GlobalRef(LinearAlgebra, :I) : @insert_node_here compact line generate_ode_mass_matrix(nd, na)::Matrix{Float64}
147+
initf = init_key !== nothing ? init_uncompress_gen!(compact, result, ci, init_key, key, world) : nothing
148+
odef = @insert_node_here compact line make_odefunction(new_oc, mass_matrix, initf)::ODEFunction true
138149

139-
@insert_node_here compact line (return odef)::Core.OpaqueClosure true
150+
odef_and_n = @insert_node_here compact line tuple(odef, nd + na)::Tuple true
151+
@insert_node_here compact line (return odef_and_n)::Core.OpaqueClosure true
140152

141153
ir_factory = Compiler.finish(compact)
142154

test/basic.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ end
2020
oneeq!()
2121
sol = solve(DAECProblem(oneeq!, (1,) .=> 1.), IDA())
2222
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], exp.(sol.t)))
23+
sol = solve(ODECProblem(oneeq!, (1,) .=> 1.), Rodas5(autodiff=false))
24+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], exp.(sol.t)))
2325

2426
#= Initial Condition =#
2527
@noinline function oneeq_ic!()
@@ -33,6 +35,8 @@ oneeq_ic!()
3335
# TODO: Sundials is broken and doesn't respect the custom initialization (https://github.com/SciML/Sundials.jl/issues/469)
3436
sol = solve(DAECProblem(oneeq_ic!), DFBDF(autodiff=false))
3537
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], exp.(sol.t)))
38+
sol = solve(ODECProblem(oneeq_ic!), Rodas5(autodiff=false))
39+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], exp.(sol.t)))
3640

3741
#= Pantelides =#
3842
function pantelides()
@@ -45,6 +49,8 @@ end
4549
pantelides()
4650
sol = solve(DAECProblem(pantelides, (1,) .=> 0.), DFBDF(autodiff=false))
4751
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], sol.t))
52+
sol = solve(ODECProblem(pantelides, (1,) .=> 0.), Rodas5(autodiff=false))
53+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], sol.t))
4854

4955
#= Structural Singularity Removal =#
5056
function ssrm()
@@ -58,6 +64,8 @@ end
5864
ssrm()
5965
sol = solve(DAECProblem(ssrm, (1,) .=> 1.), DFBDF(autodiff=false))
6066
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], exp.(0.5sol.t)))
67+
sol = solve(ODECProblem(ssrm, (1,) .=> 1.), Rodas5(autodiff=false))
68+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], exp.(0.5sol.t)))
6169

6270
#= Pantelides from init =#
6371
function pantelides_from_init()
@@ -66,4 +74,4 @@ function pantelides_from_init()
6674
initial!(x -ddt(ddt(y)))
6775
end
6876

69-
end
77+
end

test/ipo.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using DAECompiler
55
using DAECompiler.Intrinsics
66
using Sundials
77
using SciMLBase
8+
using OrdinaryDiffEq
89

910
#= Basic IPO: We need to read the incidence of the contained `-` =#
1011
@noinline function onecall!()
@@ -15,6 +16,8 @@ end
1516
onecall!()
1617
sol = solve(DAECProblem(onecall!, (1,) .=> 1.), IDA())
1718
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[1, :], exp.(sol.t)))
19+
sol = solve(ODECProblem(onecall!, (1,) .=> 1.), Rodas5(autodiff=false))
20+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[1, :], exp.(sol.t)))
1821

1922
#= + Contained Equations =#
2023
function twocall!()
@@ -26,5 +29,8 @@ twocall!()
2629
sol = solve(DAECProblem(twocall!, (1, 2) .=> 1.), IDA())
2730
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[1, :], exp.(sol.t)))
2831
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[2, :], exp.(sol.t)))
32+
sol = solve(ODECProblem(twocall!, (1, 2) .=> 1.), Rodas5(autodiff=false))
33+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[1, :], exp.(sol.t)))
34+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[2, :], exp.(sol.t)))
2935

3036
end

test/ode.jl

Lines changed: 0 additions & 27 deletions
This file was deleted.

0 commit comments

Comments
 (0)