Skip to content

Commit bae2a86

Browse files
committed
Fix handling of functors, replace flatten_parameter!
1 parent a45c5c4 commit bae2a86

File tree

8 files changed

+68
-59
lines changed

8 files changed

+68
-59
lines changed

src/analysis/flattening.jl

Lines changed: 5 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ function flatten_argument!(state::FlatteningState, @nospecialize(argt))
9090
line = compact[Compiler.OldSSAValue(1)][:line]
9191
ssa = @insert_instruction_here(compact, line, settings, (:invoke)(nothing, InternalIntrinsics.external_equation)::Eq(eq))
9292
return ssa
93-
elseif isa(argt, Type)
93+
elseif argt <: Type
9494
return argt.parameters[1]
9595
elseif isabstracttype(argt) || ismutabletype(argt) || (!isa(argt, DataType) && !isa(argt, PartialStruct))
9696
line = compact[Compiler.OldSSAValue(1)][:line]
@@ -112,19 +112,19 @@ function flatten_argument!(state::FlatteningState, @nospecialize(argt))
112112
end
113113
end
114114

115-
function flatten_arguments_for_callee!(compact::IncrementalCompact, map::ArgumentMap, argtypes, 𝕃, line, settings)
115+
function flatten_arguments_for_callee!(compact::IncrementalCompact, map::ArgumentMap, argtypes, args, line, settings, 𝕃 = Compiler.fallback_lattice)
116116
list = Any[]
117117
this = nothing
118-
last_index = Int[]
118+
last_index = CompositeIndex()
119119
for index in map.variables
120120
from = findfirst(j -> get(last_index, j, -1) !== index[j], eachindex(index))::Int
121121
for i in from:length(index)
122122
field = index[i]
123123
if i == 1
124-
this = Argument(2 + field)
124+
this = args[field]
125125
else
126126
thistype = argextype(this, compact)
127-
fieldtype = Compiler.getfield_tfunc(𝕃, Const(field))
127+
fieldtype = Compiler.getfield_tfunc(𝕃, thistype, Const(field))
128128
this = @insert_instruction_here(compact, line, settings, getfield(this, field)::fieldtype)
129129
end
130130
end
@@ -133,42 +133,6 @@ function flatten_arguments_for_callee!(compact::IncrementalCompact, map::Argumen
133133
return list
134134
end
135135

136-
function _flatten_parameter!(𝕃, compact, argtypes, ntharg, line, settings)
137-
list = Any[]
138-
for (argn, argt) in enumerate(argtypes)
139-
if isa(argt, Const)
140-
continue
141-
elseif Base.issingletontype(argt)
142-
continue
143-
elseif Base.isprimitivetype(argt) || isa(argt, Incidence)
144-
push!(list, ntharg(argn))
145-
elseif argt === equation || isa(argt, Eq)
146-
continue
147-
elseif isa(argt, Type) && argt <: Intrinsics.AbstractScope
148-
continue
149-
elseif isabstracttype(argt) || ismutabletype(argt) || (!isa(argt, DataType) && !isa(argt, PartialStruct))
150-
continue
151-
else
152-
if !isa(argt, PartialStruct) && Base.datatype_fieldcount(argt) === nothing
153-
continue
154-
end
155-
this = ntharg(argn)
156-
nthfield(i) = @insert_instruction_here(compact, line, settings, getfield(this, i)::Compiler.getfield_tfunc(𝕃, argextype(this, compact), Const(i)))
157-
if isa(argt, PartialStruct)
158-
fields = _flatten_parameter!(𝕃, compact, argt.fields, nthfield, line, settings)
159-
else
160-
fields = _flatten_parameter!(𝕃, compact, fieldtypes(argt), nthfield, line, settings)
161-
end
162-
append!(list, fields)
163-
end
164-
end
165-
return list
166-
end
167-
168-
function flatten_parameter!(𝕃, compact, argtypes, ntharg, line, settings)
169-
return @insert_instruction_here(compact, line, settings, tuple(_flatten_parameter!(𝕃, compact, argtypes, ntharg, line, settings)...)::Tuple)
170-
end
171-
172136
remove_variable_and_equation_annotations(argtypes) = Any[widenconst(T) for T in argtypes]
173137

174138
function annotate_variables_and_equations(argtypes::Vector{Any}, map::ArgumentMap)

src/analysis/structural.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,11 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings
376376
# Rewrite to flattened ABI
377377
compact[SSAValue(i)] = nothing
378378
compact.result_idx -= 1
379-
new_args = _flatten_parameter!(Compiler.optimizer_lattice(refiner), compact, callee_codeinst.inferred.ir.argtypes, arg->stmt.args[arg+1], line, settings)
379+
callee_argtypes = callee_codeinst.inferred.ir.argtypes
380+
callee_argmap = ArgumentMap(callee_argtypes)
381+
args = @view(stmt.args[2:end])
382+
𝕃 = Compiler.optimizer_lattice(refiner)
383+
new_args = flatten_arguments_for_callee!(compact, callee_argmap, callee_argtypes, args, line, settings, 𝕃)
380384
new_call = insert_instruction_here!(compact, settings, @__SOURCE__,
381385
NewInstruction(Expr(:invoke, (StructuralSSARef(compact.result_idx), callee_codeinst), new_args...), stmtype, info, line, stmtflags))
382386
compact.ssa_rename[compact.idx - 1] = new_call

src/transform/codegen/dae_factory.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Unio
147147
argt = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64}
148148
sicm = ()
149149
if settings.skip_optimizations
150-
daef_ci = rhs_finish_noopt!(state, ci, key, world, settings)
150+
daef_ci = rhs_finish_noopt!(state, ci, key, world, settings; opaque_closure = true)
151151
oc = sciml_to_internal_abi_noopt!(copy(ci.inferred.ir), state, daef_ci, settings)
152152
else
153153
# TODO: We should not have to recompute this here
@@ -164,8 +164,12 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Unio
164164
@assert sicm_ci !== nothing
165165

166166
line = result.ir[SSAValue(1)][:line]
167-
param_list = flatten_parameter!(Compiler.fallback_lattice, compact, ci.inferred.ir.argtypes[1:end], argn->Argument(2+argn), line, settings)
168-
sicm = @insert_instruction_here(compact, line, settings, invoke(param_list, sicm_ci)::Tuple)
167+
callee_argtypes = ci.inferred.ir.argtypes
168+
callee_argmap = ArgumentMap(callee_argtypes)
169+
args = Argument.(2 .+ eachindex(callee_argtypes))
170+
new_args = flatten_arguments_for_callee!(compact, callee_argmap, callee_argtypes, args, line, settings)
171+
list = @insert_instruction_here(compact, line, settings, tuple(new_args...)::Tuple)
172+
sicm = @insert_instruction_here(compact, line, settings, invoke(list, sicm_ci)::Tuple)
169173
end
170174

171175
daef_ci = rhs_finish!(state, ci, key, world, settings, 1)

src/transform/codegen/init_factory.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEI
2727
@assert sicm_ci !== nothing
2828

2929
line = result.ir[SSAValue(1)][:line]
30-
param_list = flatten_parameter!(Compiler.fallback_lattice, compact, ci.inferred.ir.argtypes[1:end], argn->Argument(2+argn), line, settings)
30+
callee_argtypes = ci.inferred.ir.argtypes
31+
callee_argmap = ArgumentMap(callee_argtypes)
32+
args = Argument.(2 .+ eachindex(callee_argtypes))
33+
new_args = flatten_arguments_for_callee!(compact, callee_argmap, callee_argtypes, args, line, settings)
34+
param_list = @insert_instruction_here(compact, line, settings, tuple(new_args...)::Tuple)
3135
sicm = @insert_instruction_here(compact, line, settings, invoke(param_list, sicm_ci)::Tuple)
3236
else
3337
sicm = ()

src/transform/codegen/ode_factory.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,12 @@ function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
7070
@assert sicm_ci !== nothing
7171

7272
line = result.ir[SSAValue(1)][:line]
73-
param_list = flatten_parameter!(Compiler.fallback_lattice, returned_ic, ci.inferred.ir.argtypes[1:end], argn->Argument(2+argn), line, settings)
74-
sicm_state = @insert_instruction_here(returned_ic, line, settings, (:call)(invoke, param_list, sicm_ci)::Tuple)
73+
callee_argtypes = ci.inferred.ir.argtypes
74+
callee_argmap = ArgumentMap(callee_argtypes)
75+
args = Argument.(2 .+ eachindex(callee_argtypes))
76+
new_args = flatten_arguments_for_callee!(returned_ic, callee_argmap, callee_argtypes, args, line, settings)
77+
param_list = @insert_instruction_here(returned_ic, line, settings, tuple(new_args...)::Tuple)
78+
sicm_state = @insert_instruction_here(returned_ic, line, settings, invoke(param_list, sicm_ci)::Tuple)
7579
else
7680
sicm_state = ()
7781
end

src/transform/reconstruct.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ derivative may differ between the unoptimized and optimized versions.
9494
function compute_residual_vectors(f, u, du; t = 1.0, mode=DAE, world=Base.tls_world_age())
9595
@assert mode === DAE # TODO: support ODEs
9696
settings = Settings(; mode, insert_stmt_debuginfo = true)
97-
ci = _code_ad_by_type(Tuple{typeof(f)}; world)
97+
tt = Base.signature_type(f, ())
98+
ci = _code_ad_by_type(tt; world)
9899
result = @code_structure result=true mode=mode world=world f()
99100
structure = make_structure_from_ipo(result)
100101
state = TransformationState(result, structure)

src/transform/unoptimized.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ function rhs_finish_noopt!(
1313
world::UInt,
1414
settings::Settings,
1515
equation_to_residual_mapping = 1:length(state.structure.eq_to_diff),
16-
variable_to_state_mapping = map_variables_to_states(state))
16+
variable_to_state_mapping = map_variables_to_states(state);
17+
opaque_closure)
1718

1819
(; result, structure) = state
1920
result_ci = find_matching_ci(ci -> ci.owner === key, ci.def, world)
@@ -23,12 +24,14 @@ function rhs_finish_noopt!(
2324

2425
ir = copy(result.ir)
2526
src = ci.inferred::AnalyzedSource
26-
argrange = 2:src.nargs
27-
slotnames = Symbol[:captures]
28-
argtypes = Any[Tuple]
27+
argrange = 1:src.nargs
2928
# Original arguments.
30-
append!(slotnames, src.slotnames[argrange])
31-
append!(argtypes, remove_variable_and_equation_annotations(ir.argtypes[argrange]))
29+
slotnames = src.slotnames[argrange]
30+
argtypes = remove_variable_and_equation_annotations(ir.argtypes)
31+
if opaque_closure
32+
slotnames[1] = :captures
33+
argtypes[1] = Tuple
34+
end
3235
# Additional ABI arguments.
3336
push!(slotnames, :out, :du, :u, :residuals, :states, :t)
3437
push!(argtypes, Vector{Float64}, Vector{Float64}, Vector{Float64}, Vector{Int}, Vector{Int}, Float64)
@@ -88,7 +91,7 @@ function rhs_finish_noopt!(
8891
inst[:stmt] = 0.0
8992
elseif isexpr(stmt, :invoke)
9093
info = inst[:info]::MappingInfo
91-
callee_ci, callee_f, original_args = stmt.args[1]::CodeInstance, stmt.args[2], @view stmt.args[3:end]
94+
callee_ci, args = stmt.args[1]::CodeInstance, @view stmt.args[2:end]
9295
callee_result = structural_analysis!(callee_ci, world, settings)
9396
callee_structure = make_structure_from_ipo(callee_result)
9497
callee_state = TransformationState(callee_result, callee_structure)
@@ -102,9 +105,8 @@ function rhs_finish_noopt!(
102105
end
103106
callee_states = [get(variable_to_state_mapping, i, -1) for i in caller_variables]
104107

105-
callee_daef_ci = rhs_finish_noopt!(callee_state, callee_ci, UnoptimizedKey(), world, settings, callee_residuals, callee_states)
106-
call = @insert_instruction_here(compact, line, settings, (:invoke)(callee_daef_ci, callee_f,
107-
original_args...,
108+
callee_daef_ci = rhs_finish_noopt!(callee_state, callee_ci, UnoptimizedKey(), world, settings, callee_residuals, callee_states; opaque_closure = false)
109+
call = @insert_instruction_here(compact, line, settings, (:invoke)(callee_daef_ci, args...,
108110
out,
109111
du,
110112
u,

test/validation.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,19 @@ function equation_with_callable!()
8585
always!(ddt(callable()) - 3.0)
8686
end
8787

88+
@noinline apply_equation!(lhs, rhs) = always!(lhs - rhs)
89+
function nonlinear_argument!()
90+
x = continuous()
91+
apply_equation!(ddt(x), sin(x))
92+
end
93+
94+
struct WithParameter{N} end
95+
@noinline (::WithParameter{N})(eq, x) where {N} = eq(ddt(x) - N)
96+
function callable_with_type_parameter!()
97+
eq, x = new_equation_and_variable()
98+
WithParameter{3}()(eq, x)
99+
end
100+
88101
@testset "Validation" begin
89102
refresh() # TODO: remove before merge
90103

@@ -132,6 +145,12 @@ end
132145
@test all(>(0), residuals)
133146
@test residuals expanded_residuals
134147

148+
u = [2.0]
149+
du = [4.0]
150+
residuals, expanded_residuals = compute_residual_vectors(nonlinear_argument!, u, du)
151+
@test residuals du .- sin.(u)
152+
@test residuals expanded_residuals
153+
135154
u = [0.0]
136155
du = [2.0]
137156
residuals, expanded_residuals = compute_residual_vectors(external_equation!, u, du)
@@ -161,4 +180,11 @@ end
161180
residuals, expanded_residuals = compute_residual_vectors(equation_with_callable!, u, du)
162181
@test residuals [1.0]
163182
@test residuals expanded_residuals
183+
184+
u = [2.0]
185+
du = [3.0]
186+
refresh(); residuals, expanded_residuals = compute_residual_vectors(callable_with_type_parameter!, u, du)
187+
@test residuals [0.0]
188+
# XXX: Residuals from the optimized pipeline are wrong (3.0 instead of 0.0)
189+
@test_broken residuals expanded_residuals
164190
end;

0 commit comments

Comments
 (0)