Skip to content

Commit 5386832

Browse files
authored
Merge pull request #3034 from SciML/myb/unit_fix
Add unit aware diff2term
2 parents be151e3 + 9deb045 commit 5386832

File tree

8 files changed

+50
-8
lines changed

8 files changed

+50
-8
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using SymbolicUtils: maketerm, iscall
1111

1212
using ModelingToolkit
1313
using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Differential,
14-
unknowns, equations, vars, Symbolic, diff2term, value,
14+
unknowns, equations, vars, Symbolic, diff2term_with_unit, value,
1515
operation, arguments, Sym, Term, simplify, symbolic_linear_solve,
1616
isdiffeq, isdifferential, isirreducible,
1717
empty_substitutions, get_substitutions,
@@ -22,7 +22,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2222
get_postprocess_fbody, vars!,
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2424
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
25-
filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
25+
filter_kwargs, lower_varname_with_unit, setio, SparseMatrixCLIL,
2626
get_fullvars, has_equations, observed,
2727
Schedule
2828

src/structural_transformation/pantelides.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
1616
fill!(out_vars, nothing)
1717
out_vars[1:length(fullvars)] .= fullvars
1818

19-
D = Differential(get_iv(sys))
19+
iv = get_iv(sys)
20+
D = Differential(iv)
2021

2122
for (varidx, diff) in edges(var_to_diff)
2223
# fullvars[diff] = D(fullvars[var])
@@ -25,7 +26,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
2526
# `fullvars[i]` needs to be not a `D(...)`, because we want the DAE to be
2627
# first-order.
2728
if isdifferential(vi)
28-
vi = out_vars[varidx] = diff2term(vi)
29+
vi = out_vars[varidx] = diff2term_with_unit(vi, iv)
2930
end
3031
out_vars[diff] = D(vi)
3132
end

src/structural_transformation/symbolics_tearing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
274274
dv === nothing && continue
275275
if var_eq_matching[var] !== SelectedState()
276276
dd = fullvars[dv]
277-
v_t = setio(diff2term(unwrap(dd)), false, false)
277+
v_t = setio(diff2term_with_unit(unwrap(dd), unwrap(iv)), false, false)
278278
for eq in 𝑑neighbors(graph, dv)
279279
dummy_sub[dd] = v_t
280280
neweqs[eq] = fast_substitute(neweqs[eq], dd => v_t)

src/structural_transformation/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ function lower_varname_withshift(var, iv, order)
426426
op = operation(var)
427427
return Shift(op.t, order)(var)
428428
end
429-
return lower_varname(var, iv, order)
429+
return lower_varname_with_unit(var, iv, order)
430430
end
431431

432432
function isdoubleshift(var)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ end
456456

457457
function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
458458
ps = parameters(sys), u0 = nothing;
459-
ddvs = map(diff2term Differential(get_iv(sys)), dvs),
459+
ddvs = map(Base.Fix2(diff2term, get_iv(sys)) Differential(get_iv(sys)), dvs),
460460
version = nothing, p = nothing,
461461
jac = false,
462462
eval_expression = false,

src/systems/diffeqs/first_order_transform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function ode_order_lowering(eqs, iv, unknown_vars)
3434
var, maxorder = var_from_nested_derivative(eq.lhs)
3535
maxorder > get(var_order, var, 1) && (var_order[var] = maxorder)
3636
var′ = lower_varname(var, iv, maxorder - 1)
37-
rhs′ = diff2term(eq.rhs)
37+
rhs′ = diff2term_with_unit(eq.rhs, iv)
3838
push!(diff_vars, var′)
3939
push!(diff_eqs, D(var′) ~ rhs′)
4040
end

src/utils.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,3 +859,16 @@ function eval_or_rgf(expr::Expr; eval_expression = false, eval_module = @__MODUL
859859
return drop_expr(RuntimeGeneratedFunction(eval_module, eval_module, expr))
860860
end
861861
end
862+
863+
function _with_unit(f, x, t, args...)
864+
x = f(x, args...)
865+
if hasmetadata(x, VariableUnit) && (t isa Symbolic && hasmetadata(t, VariableUnit))
866+
xu = getmetadata(x, VariableUnit)
867+
tu = getmetadata(t, VariableUnit)
868+
x = setmetadata(x, VariableUnit, xu / tu)
869+
end
870+
return x
871+
end
872+
873+
diff2term_with_unit(x, t) = _with_unit(diff2term, x, t)
874+
lower_varname_with_unit(var, iv, order) = _with_unit(lower_varname, var, iv, iv, order)

test/dq_units.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,31 @@ let
246246
@test MT.get_unit(x_vec) == u"1"
247247
@test MT.get_unit(x_mat) == u"1"
248248
end
249+
250+
module UnitTD
251+
using Test
252+
using ModelingToolkit
253+
using ModelingToolkit: t, D
254+
using DynamicQuantities
255+
256+
@mtkmodel UnitsExample begin
257+
@parameters begin
258+
g, [unit = u"m/s^2"]
259+
L = 1.0, [unit = u"m"]
260+
end
261+
@variables begin
262+
x(t), [unit = u"m"]
263+
y(t), [state_priority = 10, unit = u"m"]
264+
λ(t), [unit = u"s^-2"]
265+
end
266+
@equations begin
267+
D(D(x)) ~ λ * x
268+
D(D(y)) ~ λ * y - g
269+
x^2 + y^2 ~ L^2
270+
end
271+
end
272+
273+
@mtkbuild pend = UnitsExample()
274+
@test ModelingToolkit.get_unit.(filter(x -> occursin("ˍt", string(x)), unknowns(pend))) ==
275+
[u"m/s", u"m/s"]
276+
end

0 commit comments

Comments
 (0)