@@ -2,12 +2,11 @@ function expand_residuals(state::TransformationState, key::TornCacheKey, compres
22 (; result, structure) = state
33 expanded = Float64[]
44 i = 1
5- # TODO : Remove this and the `key` argument if unused
6- # var_eq_matching = matching_for_key(state, key)
7- # (slot_assignments, var_assignment, _) = assign_slots(state, key, var_eq_matching)
5+ var_eq_matching = matching_for_key (state, key)
86 for (eq, incidence) in enumerate (result. total_incidence)
97 if ! is_var_part_known_linear (incidence)
10- push! (expanded, compressed[i])
8+ sign = infer_residual_sign (result, eq, var_eq_matching)
9+ push! (expanded, sign * compressed[i])
1110 i += 1
1211 continue
1312 end
@@ -21,7 +20,7 @@ function expand_residuals(state::TransformationState, key::TornCacheKey, compres
2120 is_diff = is_differential_variable (structure, var)
2221 source = ifelse (is_diff, du, u)
2322 slot = count (v -> is_differential_variable (structure, v) == is_diff, 1 : var)
24- eq == var && ! is_diff && (i += 1 )
23+ var === invview (var_eq_matching)[eq] && ! is_diff && (i += 1 )
2524 value = source[slot]
2625 end
2726 residual += value * coeff
@@ -32,6 +31,22 @@ function expand_residuals(state::TransformationState, key::TornCacheKey, compres
3231 return expanded
3332end
3433
34+ function infer_residual_sign (result:: DAEIPOResult , eq:: Int , var_eq_matching)
35+ # If a linear solved term appears with a positive coefficient,
36+ # the residual will be taken as the negative of the value provided to `always!`.
37+ # For example: ẋ₁ - x₁x₂ = 0
38+ # -ẋ₁ = -x₁x₂
39+ # ẋ₁ = -x₁x₂/-1
40+ # ẋ₁ = x₁x₂
41+ # 0 = x₁x₂ - ẋ₁ <-- residual
42+ incidence = result. total_incidence[eq]
43+ var = invview (var_eq_matching)[eq]
44+ isa (var, Int) || return 1
45+ coeff = incidence. row[var + 1 ]
46+ isa (coeff, Float64) || return - 1
47+ return - sign (coeff)
48+ end
49+
3550function is_differential_variable (structure:: DAESystemStructure , var)
3651 structure. var_to_diff[var] != = nothing && return false
3752 return invview (structure. var_to_diff)[var] != = nothing && return true
@@ -78,7 +93,7 @@ derivative may differ between the unoptimized and optimized versions.
7893"""
7994function compute_residual_vectors (f, u, du; t = rand (), mode= DAE, world= Base. tls_world_age ())
8095 @assert mode === DAE # TODO : support ODEs
81- settings = Settings (; mode)
96+ settings = Settings (; mode, insert_stmt_debuginfo = true )
8297 ci = _code_ad_by_type (Tuple{typeof (f)}; world)
8398 result = @code_structure result= true mode= mode world= world f ()
8499 structure = make_structure_from_ipo (result)
0 commit comments