Skip to content

Commit 57f65d7

Browse files
committed
Adjust residual sign based on solved variable coefficient
1 parent 4402b42 commit 57f65d7

File tree

2 files changed

+54
-6
lines changed

2 files changed

+54
-6
lines changed

src/transform/reconstruct.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@ function expand_residuals(state::TransformationState, key::TornCacheKey, compres
22
(; result, structure) = state
33
expanded = Float64[]
44
i = 1
5-
# TODO: Remove this and the `key` argument if unused
6-
# var_eq_matching = matching_for_key(state, key)
7-
# (slot_assignments, var_assignment, _) = assign_slots(state, key, var_eq_matching)
5+
var_eq_matching = matching_for_key(state, key)
86
for (eq, incidence) in enumerate(result.total_incidence)
97
if !is_var_part_known_linear(incidence)
10-
push!(expanded, compressed[i])
8+
sign = infer_residual_sign(result, eq, var_eq_matching)
9+
push!(expanded, sign * compressed[i])
1110
i += 1
1211
continue
1312
end
@@ -21,7 +20,7 @@ function expand_residuals(state::TransformationState, key::TornCacheKey, compres
2120
is_diff = is_differential_variable(structure, var)
2221
source = ifelse(is_diff, du, u)
2322
slot = count(v -> is_differential_variable(structure, v) == is_diff, 1:var)
24-
eq == var && !is_diff && (i += 1)
23+
var === invview(var_eq_matching)[eq] && !is_diff && (i += 1)
2524
value = source[slot]
2625
end
2726
residual += value * coeff
@@ -32,6 +31,22 @@ function expand_residuals(state::TransformationState, key::TornCacheKey, compres
3231
return expanded
3332
end
3433

34+
function infer_residual_sign(result::DAEIPOResult, eq::Int, var_eq_matching)
35+
# If a linear solved term appears with a positive coefficient,
36+
# the residual will be taken as the negative of the value provided to `always!`.
37+
# For example: ẋ₁ - x₁x₂ = 0
38+
# -ẋ₁ = -x₁x₂
39+
# ẋ₁ = -x₁x₂/-1
40+
# ẋ₁ = x₁x₂
41+
# 0 = x₁x₂ - ẋ₁ <-- residual
42+
incidence = result.total_incidence[eq]
43+
var = invview(var_eq_matching)[eq]
44+
isa(var, Int) || return 1
45+
coeff = incidence.row[var + 1]
46+
isa(coeff, Float64) || return -1
47+
return -sign(coeff)
48+
end
49+
3550
function is_differential_variable(structure::DAESystemStructure, var)
3651
structure.var_to_diff[var] !== nothing && return false
3752
return invview(structure.var_to_diff)[var] !== nothing && return true
@@ -78,7 +93,7 @@ derivative may differ between the unoptimized and optimized versions.
7893
"""
7994
function compute_residual_vectors(f, u, du; t = rand(), mode=DAE, world=Base.tls_world_age())
8095
@assert mode === DAE # TODO: support ODEs
81-
settings = Settings(; mode)
96+
settings = Settings(; mode, insert_stmt_debuginfo = true)
8297
ci = _code_ad_by_type(Tuple{typeof(f)}; world)
8398
result = @code_structure result=true mode=mode world=world f()
8499
structure = make_structure_from_ipo(result)

test/validation.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,26 @@ end
2727
always!(ddt(x) - x)
2828
end
2929

30+
@noinline function sin!()
31+
x = continuous()
32+
always!(ddt(x) - sin(x))
33+
end
34+
35+
@noinline function neg_sin!()
36+
x = continuous()
37+
always!(sin(x) - ddt(x))
38+
end
39+
3040
function twocall!()
3141
onecall!(); onecall!();
3242
return nothing
3343
end
3444

45+
function sin2!()
46+
sin!(); sin!();
47+
return nothing
48+
end
49+
3550
@testset "Validation" begin
3651
refresh() # TODO: remove before merge
3752

@@ -47,6 +62,18 @@ end
4762
@test residuals [0.0, -3.0, 97.0, 13.0]
4863
@test residuals expanded_residuals
4964

65+
u = [2.0]
66+
du = [3.0]
67+
residuals, expanded_residuals = compute_residual_vectors(sin!, u, du; t = 1.0)
68+
@test residuals du .- sin.(u)
69+
@test residuals expanded_residuals
70+
71+
u = [2.0]
72+
du = [3.0]
73+
residuals, expanded_residuals = compute_residual_vectors(neg_sin!, u, du; t = 1.0)
74+
@test residuals sin.(u) .- du
75+
@test residuals expanded_residuals
76+
5077
# IPO
5178

5279
u = [2.0]
@@ -60,4 +87,10 @@ end
6087
residuals, expanded_residuals = compute_residual_vectors(twocall!, u, du; t = 1.0)
6188
@test residuals [1.0, 3.0]
6289
@test residuals expanded_residuals
90+
91+
u = [2.0, 4.0]
92+
du = [1.0, 1.0]
93+
residuals, expanded_residuals = compute_residual_vectors(sin2!, u, du; t = 1.0)
94+
@test all(>(0), residuals)
95+
@test residuals expanded_residuals
6396
end;

0 commit comments

Comments
 (0)