Skip to content

Commit e676d8f

Browse files
committed
Update lower_varname as well
1 parent 01d2daf commit e676d8f

File tree

7 files changed

+38
-8
lines changed

7 files changed

+38
-8
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2222
get_postprocess_fbody, vars!,
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2424
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
25-
filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
25+
filter_kwargs, lower_varname_with_unit, setio, SparseMatrixCLIL,
2626
get_fullvars, has_equations, observed,
2727
Schedule
2828

src/structural_transformation/pantelides.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
1616
fill!(out_vars, nothing)
1717
out_vars[1:length(fullvars)] .= fullvars
1818

19-
D = Differential(get_iv(sys))
19+
iv = get_iv(sys)
20+
D = Differential(iv)
2021

2122
for (varidx, diff) in edges(var_to_diff)
2223
# fullvars[diff] = D(fullvars[var])
@@ -25,7 +26,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
2526
# `fullvars[i]` needs to be not a `D(...)`, because we want the DAE to be
2627
# first-order.
2728
if isdifferential(vi)
28-
vi = out_vars[varidx] = diff2term(vi)
29+
vi = out_vars[varidx] = diff2term_with_unit(vi, iv)
2930
end
3031
out_vars[diff] = D(vi)
3132
end

src/structural_transformation/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ function lower_varname_withshift(var, iv, order)
426426
op = operation(var)
427427
return Shift(op.t, order)(var)
428428
end
429-
return lower_varname(var, iv, order)
429+
return lower_varname_with_unit(var, iv, order)
430430
end
431431

432432
function isdoubleshift(var)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ end
456456

457457
function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
458458
ps = parameters(sys), u0 = nothing;
459-
ddvs = map(diff2term Differential(get_iv(sys)), dvs),
459+
ddvs = map(Base.Fix2(diff2term, get_iv(sys)) Differential(get_iv(sys)), dvs),
460460
version = nothing, p = nothing,
461461
jac = false,
462462
eval_expression = false,

src/systems/diffeqs/first_order_transform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function ode_order_lowering(eqs, iv, unknown_vars)
3434
var, maxorder = var_from_nested_derivative(eq.lhs)
3535
maxorder > get(var_order, var, 1) && (var_order[var] = maxorder)
3636
var′ = lower_varname(var, iv, maxorder - 1)
37-
rhs′ = diff2term(eq.rhs)
37+
rhs′ = diff2term_with_unit(eq.rhs, iv)
3838
push!(diff_vars, var′)
3939
push!(diff_eqs, D(var′) ~ rhs′)
4040
end

src/utils.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -860,12 +860,15 @@ function eval_or_rgf(expr::Expr; eval_expression = false, eval_module = @__MODUL
860860
end
861861
end
862862

863-
function diff2term_with_unit(x, t)
864-
x = diff2term(x)
863+
function _with_unit(f, x, t, args...)
864+
x = f(x, args...)
865865
if hasmetadata(x, VariableUnit) && (t isa Symbolic && hasmetadata(t, VariableUnit))
866866
xu = getmetadata(x, VariableUnit)
867867
tu = getmetadata(t, VariableUnit)
868868
x = setmetadata(x, VariableUnit, xu / tu)
869869
end
870870
return x
871871
end
872+
873+
diff2term_with_unit(x, t) = _with_unit(diff2term, x, t)
874+
lower_varname_with_unit(var, iv, order) = _with_unit(lower_varname, var, iv, iv, order)

test/dq_units.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,29 @@ let
246246
@test MT.get_unit(x_vec) == u"1"
247247
@test MT.get_unit(x_mat) == u"1"
248248
end
249+
250+
module UnitTD
251+
using ModelingToolkit
252+
using ModelingToolkit: t, D
253+
using DynamicQuantities
254+
255+
@mtkmodel UnitsExample begin
256+
@parameters begin
257+
g, [unit = u"m/s^2"]
258+
L = 1.0, [unit = u"m"]
259+
end
260+
@variables begin
261+
x(t), [unit = u"m"]
262+
y(t), [state_priority = 10, unit = u"m"]
263+
λ(t), [unit = u"s^-2"]
264+
end
265+
@equations begin
266+
D(D(x)) ~ λ * x
267+
D(D(y)) ~ λ * y - g
268+
x^2 + y^2 ~ L^2
269+
end
270+
end
271+
272+
@mtkbuild pend = UnitsExample()
273+
@test ModelingToolkit.get_unit.(filter(x -> occursin("ˍt", string(x)), unknowns(pend))) == [u"m/s", u"m/s"]
274+
end

0 commit comments

Comments
 (0)