Skip to content

Commit e6f621f

Browse files
committed
Support type constructors for IPO
1 parent 5f796d2 commit e6f621f

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

src/analysis/flattening.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ function flatten_argument!(state::FlatteningState, @nospecialize(argt))
9090
line = compact[Compiler.OldSSAValue(1)][:line]
9191
ssa = @insert_instruction_here(compact, line, settings, (:invoke)(nothing, InternalIntrinsics.external_equation)::Eq(eq))
9292
return ssa
93+
elseif isa(argt, Type)
94+
return argt.parameters[1]
9395
elseif isabstracttype(argt) || ismutabletype(argt) || (!isa(argt, DataType) && !isa(argt, PartialStruct))
9496
line = compact[Compiler.OldSSAValue(1)][:line]
9597
ssa = @insert_instruction_here(compact, line, settings, error("Cannot IPO model arg type $argt")::Union{})
@@ -245,6 +247,7 @@ function annotate_variables_and_equations(argtypes::Vector{Any}, map::ArgumentMa
245247
end
246248

247249
init_partialstruct(@nospecialize(T)) = PartialStruct(T, collect(Any, fieldtypes(T)))
250+
init_partialstruct(pstruct::PartialStruct) = pstruct
248251

249252
function find_base(dict::Dict{CompositeIndex}, index::CompositeIndex)
250253
for i in reverse(eachindex(index))

src/transform/reconstruct.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,23 +104,26 @@ function compute_residual_vectors(f, u, du; t = 1.0, mode=DAE, world=Base.tls_wo
104104
torn_ir = torn_ci.inferred
105105
removed_states = extract_removed_states(state, key, torn_ir, u, du, t)
106106

107-
residuals = zeros(length(u))
108-
p = SciMLBase.NullParameters()
109-
indices = filter(!in(removed_states), eachindex(u))
110-
u_compressed = u[indices]
111-
du_compressed = du[indices]
112-
residuals_compressed = zeros(length(residuals) - length(removed_states))
113-
114107
our_prob = DAECProblem(f, (1,) .=> 1., insert_stmt_debuginfo = true)
115108
sciml_prob = DiffEqBase.get_concrete_problem(our_prob, true)
116109
f_compressed! = sciml_prob.f.f
117-
f_compressed!(residuals_compressed, du_compressed, u_compressed, p, t)
118110

119111
our_prob = DAECProblem(f, (1,) .=> 1., insert_stmt_debuginfo = true, skip_optimizations = true)
120112
sciml_prob = DiffEqBase.get_concrete_problem(our_prob, true)
121113
f_original! = sciml_prob.f.f
122-
f_original!(residuals, du, u, p, t)
123114

115+
residuals = zeros(length(u))
116+
p = SciMLBase.NullParameters()
117+
indices = filter(!in(removed_states), eachindex(u))
118+
u_compressed = u[indices]
119+
du_compressed = du[indices]
120+
121+
n = length(residuals) - length(removed_states)
122+
@assert n 1
123+
residuals_compressed = zeros(n)
124+
f_compressed!(residuals_compressed, du_compressed, u_compressed, p, t)
125+
f_original!(residuals, du, u, p, t)
124126
expanded = expand_residuals(f, residuals_compressed, u, du, t)
127+
125128
return residuals, expanded
126129
end

test/validation.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ function equation_used_multiple_times!()
7979
apply_equation_on_ddtx_minus_one(eq, x)
8080
end
8181

82+
function equation_with_callable!()
83+
x = continuous()
84+
callable = @noinline Returns(x)
85+
always!(ddt(callable()) - 3.0)
86+
end
87+
8288
@testset "Validation" begin
8389
refresh() # TODO: remove before merge
8490

@@ -149,4 +155,10 @@ end
149155
residuals, expanded_residuals = compute_residual_vectors(equation_used_multiple_times!, u, du)
150156
@test residuals [6.0]
151157
@test residuals expanded_residuals
158+
159+
u = [2.0]
160+
du = [4.0]
161+
residuals, expanded_residuals = compute_residual_vectors(equation_with_callable!, u, du)
162+
@test residuals [1.0]
163+
@test residuals expanded_residuals
152164
end;

0 commit comments

Comments
 (0)