Skip to content

Commit adb7e39

Browse files
authored
Fix varargs IPO test (#50)
1 parent a06aeb2 commit adb7e39

File tree

11 files changed

+53
-21
lines changed

11 files changed

+53
-21
lines changed

Manifest.toml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/analysis/ADAnalyzer.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ struct ADAnalyzer <: Compiler.AbstractInterpreter
2727
end
2828

2929
Compiler.InferenceParams(interp::ADAnalyzer) = Compiler.InferenceParams()
30-
Compiler.OptimizationParams(interp::ADAnalyzer) = Compiler.OptimizationParams()
30+
Compiler.OptimizationParams(interp::ADAnalyzer) = Compiler.OptimizationParams(;
31+
assume_fatal_throw = true,
32+
compilesig_invokes = false
33+
)
3134
Compiler.get_inference_world(interp::ADAnalyzer) = interp.world
3235
Compiler.get_inference_cache(interp::ADAnalyzer) = interp.inf_cache
3336
Compiler.cache_owner(::ADAnalyzer) = ADCache()
@@ -60,6 +63,8 @@ end
6063
struct AnalyzedSource
6164
ir::Compiler.IRCode
6265
inline_cost::Compiler.InlineCostType
66+
nargs::UInt
67+
isva::Bool
6368
end
6469

6570
@override function Compiler.result_edges(interp::ADAnalyzer, caller::InferenceState)
@@ -70,7 +75,7 @@ end
7075
@override function Compiler.transform_result_for_cache(interp::ADAnalyzer, result::InferenceResult, edges::SimpleVector)
7176
ir = result.src.optresult.ir
7277
params = Compiler.OptimizationParams(interp)
73-
return AnalyzedSource(ir, Compiler.compute_inlining_cost(interp, result))
78+
return AnalyzedSource(ir, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
7479
end
7580

7681
@override function Compiler.transform_result_for_local_cache(interp::ADAnalyzer, result::InferenceResult)
@@ -79,7 +84,7 @@ end
7984
end
8085
ir = result.src.optresult.ir
8186
params = Compiler.OptimizationParams(interp)
82-
return AnalyzedSource(ir, Compiler.compute_inlining_cost(interp, result))
87+
return AnalyzedSource(ir, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
8388
end
8489

8590
function Compiler.retrieve_ir_for_inlining(ci::CodeInstance, result::AnalyzedSource)

src/analysis/ipoincidence.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,13 @@ function apply_linear_incidence(𝕃, ret::PartialStruct, caller::CallerMappingS
155155
return PartialStruct(𝕃, ret.typ, Any[apply_linear_incidence(𝕃, f, caller, mapping) for f in ret.fields])
156156
end
157157

158-
function CalleeMapping(𝕃::Compiler.AbstractLattice, argtypes::Vector{Any}, callee_result::DAEIPOResult, template_argtypes)
158+
function CalleeMapping(𝕃::Compiler.AbstractLattice, argtypes::Vector{Any}, callee_ci::CodeInstance, callee_result::DAEIPOResult, template_argtypes)
159159
applied_scopes = Any[]
160160
coeffs = Vector{Any}(undef, length(callee_result.var_to_diff))
161161
eq_mapping = fill(0, length(callee_result.total_incidence))
162162

163-
process_template!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes, template_argtypes)
163+
va_argtypes = Compiler.va_process_argtypes(𝕃, argtypes, callee_ci.inferred.nargs, callee_ci.inferred.isva)
164+
process_template!(𝕃, coeffs, eq_mapping, applied_scopes, va_argtypes, template_argtypes)
164165

165166
return CalleeMapping(coeffs, eq_mapping, applied_scopes)
166167
end

src/analysis/refiner.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,7 @@ Compiler.cache_owner(::StructuralRefiner) = StructureCache()
5757
end
5858

5959
argtypes = Compiler.collect_argtypes(interp, stmt.args, Compiler.StatementState(nothing, false), irsv)[2:end]
60-
m = Compiler.get_ci_mi(callee_codeinst).def
61-
argtypes = Compiler.va_process_argtypes(Compiler.optimizer_lattice(interp), argtypes, UInt(m.nargs), m.isva)
62-
mapping = CalleeMapping(Compiler.optimizer_lattice(interp), argtypes, callee_result, callee_codeinst.inferred.ir.argtypes)
60+
mapping = CalleeMapping(Compiler.optimizer_lattice(interp), argtypes, callee_codeinst, callee_result, callee_codeinst.inferred.ir.argtypes)
6361
new_rt = apply_linear_incidence(Compiler.optimizer_lattice(interp), callee_result.extended_rt,
6462
CallerMappingState(callee_result, interp.var_to_diff, interp.varclassification, interp.varkinds, VarEqClassification[]), mapping)
6563

@@ -87,7 +85,7 @@ function structural_inc_ddt(var_to_diff::DiffGraph, varclassification::Union{Vec
8785
push!(varclassification, varclassification[v])
8886
end
8987
if varkinds !== nothing
90-
push!(varkinds, Intrinsics.Continuous)
88+
push!(varkinds, varkinds[v])
9189
end
9290
add_edge!(var_to_diff, v, dv)
9391
return dv + 1

src/analysis/structural.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ function _structural_analysis!(ci::CodeInstance, world::UInt)
8080
empty!(ir.argtypes)
8181
(arg_replacements, new_argtypes, nexternalargvars) = flatten_arguments!(compact, old_argtypes, 0, ir.argtypes)
8282
for i = 1:nexternalargvars
83+
# TODO: Need to handle different var kinds for IPO
8384
add_variable!(Argument(i))
8485
end
8586
argtypes = Any[Incidence(new_argtypes[i], i) for i = 1:nexternalargvars]
@@ -212,6 +213,12 @@ function _structural_analysis!(ci::CodeInstance, world::UInt)
212213
return UncompilableIPOResult(warnings, UnsupportedIRException("Saw invalid variable kind (`$kind`) for variable $(var_num) (SSA $ssa)", ir))
213214
end
214215
varkinds[var_num] = kind.val
216+
dv = var_num
217+
while true
218+
dv = var_to_diff[dv]
219+
dv === nothing && break
220+
varkinds[dv] = kind.val
221+
end
215222

216223
scope = argextype(inst.args[4], ir)
217224
if (!isa(scope, Const) || !isa(scope.val, Intrinsics.AbstractScope)) && !is_valid_partial_scope(scope)
@@ -330,7 +337,7 @@ function _structural_analysis!(ci::CodeInstance, world::UInt)
330337
end
331338

332339
callee_argtypes = Any[argextype(stmt.args[i], compact) for i in 2:length(stmt.args)]
333-
mapping = CalleeMapping(Compiler.optimizer_lattice(refiner), callee_argtypes, result, callee_codeinst.inferred.ir.argtypes)
340+
mapping = CalleeMapping(Compiler.optimizer_lattice(refiner), callee_argtypes, callee_codeinst, result, callee_codeinst.inferred.ir.argtypes)
334341
inst[:info] = info = MappingInfo(info, result, mapping)
335342
end
336343

src/transform/codegen/dae_factory.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
9696

9797
all_states = Int[]
9898
for var = 1:length(result.var_to_diff)
99+
varkind(state, var) == Intrinsics.Continuous || continue
99100
kind = classify_var(result.var_to_diff, key, var)
100101
kind == nothing && continue
101102
numstates[kind] += 1

src/transform/codegen/ode_factory.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
8585

8686
all_states = Int[]
8787
for var = 1:length(result.var_to_diff)
88+
varkind(state, var) == Intrinsics.Continuous || continue
8889
kind = classify_var(result.var_to_diff, key, var)
8990
kind == nothing && continue
9091
numstates[kind] += 1

src/transform/codegen/rhs.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function handle_contribution!(ir::Compiler.IRCode, inst::Compiler.Instruction, k
2525
replace_call!(ir, pos, Expr(:call, Base.setindex!, which, sum, slot))
2626
end
2727

28-
function compute_slot_ranges(info::MappingInfo, callee_key, var_assignment, eq_assignment)
28+
function compute_slot_ranges(caller_state::TransformationState, info::MappingInfo, callee_key, var_assignment, eq_assignment)
2929
# Compute the ranges for this child's states in the parent range.
3030
# We rely upon earlier stages of the pipeline having put these adjacent to each other
3131
# and in order. We could just trust that, but because it's a little bit tricky, here
@@ -47,6 +47,7 @@ function compute_slot_ranges(info::MappingInfo, callee_key, var_assignment, eq_a
4747
caller_map = info.mapping.var_coeffs[callee_var]
4848
isa(caller_map, Const) && continue
4949
caller_var = only(rowvals(caller_map.row))-1
50+
varkind(caller_state, caller_var) == Intrinsics.Continuous || continue
5051

5152
callee_kind = classify_var(info.result.var_to_diff, callee_key, callee_var)
5253
callee_kind === nothing && continue
@@ -166,7 +167,7 @@ function rhs_finish!(
166167
push!(stmt.args, in_vars)
167168

168169
# Ordering from tearing is (AssignedDiff, UnassignedDiff, Algebraic, Explicit)
169-
slot_ranges = compute_slot_ranges(info, callee_key, var_assignment, eq_assignment)
170+
slot_ranges = compute_slot_ranges(state, info, callee_key, var_assignment, eq_assignment)
170171
for (arg, range) in zip(arg_range, slot_ranges)
171172
push!(stmt.args, insert_node!(ir, SSAValue(i),
172173
NewInstruction(inst;
@@ -183,6 +184,12 @@ function rhs_finish!(
183184
error()
184185
elseif is_known_invoke(stmt, variable, ir)
185186
varnum = idnum(ir.stmts.type[i])
187+
kind = varkind(state, varnum)
188+
if kind == Intrinsics.Epsilon
189+
replace_call!(ir, SSAValue(i), 0.)
190+
continue
191+
end
192+
@assert kind == Intrinsics.Continuous
186193

187194
assgn = var_assignment[varnum]
188195
if assgn == nothing

src/transform/common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ function cache_dae_ci!(old_ci, src, debuginfo, abi, owner; rettype=Tuple)
101101
return daef_ci
102102
end
103103

104-
function replace_call!(ir::Union{IRCode,IncrementalCompact}, idx::SSAValue, new_call::Expr)
104+
function replace_call!(ir::Union{IRCode,IncrementalCompact}, idx::SSAValue, @nospecialize(new_call))
105105
@assert !isa(ir[idx][:inst], PhiNode)
106106
ir[idx][:inst] = new_call
107107
ir[idx][:type] = Any

test/basic.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,18 @@ end
127127
dae_sol = solve(DAECProblem(sicm_vars(1.0), (1,) .=> 1.), IDA())
128128
ode_sol = solve(ODECProblem(sicm_vars(1.0), (1,) .=> 1.), Rodas5(autodiff=false))
129129
for sol in (dae_sol, ode_sol)
130-
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], 1. .+ sol.t))
130+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[1, :], 1. .+ sol.t))
131131
end
132132

133+
#= epsilon =#
134+
function simple_eps()
135+
x = continuous()
136+
always!(ddt(x) -ᵢ x +epsilon())
137+
end
138+
dae_sol = solve(DAECProblem(simple_eps, (1,) .=> 1.), IDA())
139+
ode_sol = solve(ODECProblem(simple_eps, (1,) .=> 1.), Rodas5(autodiff=false))
140+
for sol in (dae_sol, ode_sol)
141+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol[1, :], exp.(sol.t)))
142+
end
133143

134144
end

0 commit comments

Comments
 (0)