Skip to content

Commit 90789c0

Browse files
committed
WIP
1 parent 397ab12 commit 90789c0

File tree

7 files changed

+76
-15
lines changed

7 files changed

+76
-15
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1717
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1818
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1919
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
20+
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
2021
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2122
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
2223
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -71,6 +72,7 @@ DiffRules = "0.1, 1.0"
7172
Distributions = "0.23, 0.24, 0.25"
7273
DocStringExtensions = "0.7, 0.8, 0.9"
7374
DomainSets = "0.6"
75+
DynamicQuantities = "0.8"
7476
ForwardDiff = "0.10.3"
7577
FunctionWrappersWrappers = "0.1"
7678
Graphs = "1.5.2"

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ include("systems/pde/pdesystem.jl")
152152

153153
include("systems/sparsematrixclil.jl")
154154
include("systems/discrete_system/discrete_system.jl")
155+
include("systems/unit_check.jl")
155156
include("systems/validation.jl")
156157
include("systems/dependency_graphs.jl")
157158
include("clock.jl")

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ struct ODESystem <: AbstractODESystem
163163
check_equations(equations(cevents), iv)
164164
end
165165
if checks == true || (checks & CheckUnits) > 0
166-
all_dimensionless([dvs; ps; iv]) || check_units(deqs)
166+
u = __get_unit_type(dvs, ps, iv)
167+
check_units(u, deqs)
167168
end
168169
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
169170
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,

src/systems/unit_check.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import DynamicQuantities
2+
const DQ = DynamicQuantities
3+
4+
struct ValidationError <: Exception
5+
message::String
6+
end
7+
8+
check_units(::Nothing, _...) = true
9+
10+
__get_literal_unit(x) = getmetadata(x, VariableUnit, nothing)
11+
function __get_unit_type(vs′...)
12+
vs = Iterators.flatten(vs′)
13+
for v in vs
14+
u = __get_literal_unit(v)
15+
if u isa DQ.AbstractQuantity
16+
return Val(:Unitful)
17+
else
18+
return Val(:DynamicQuantities)
19+
end
20+
end
21+
return nothing
22+
end
23+
24+
function check_units(::Val{:DynamicQuantities}, eqs...)
25+
validate(eqs...) ||
26+
throw(ValidationError("Some equations had invalid units. See warnings for details."))
27+
end
28+
29+
function screen_units(result)
30+
if result isa DQ.AbstractQuantity
31+
d = DQ.dimension(result)
32+
if d isa DQ.Dimensions
33+
return result
34+
elseif d isa DQ.SymbolicDimensions
35+
throw(ValidationError("$result uses SymbolicDimensions, please use `u\"m\"` to instantiate SI unit only."))
36+
else
37+
throw(ValidationError("$result doesn't use SI unit, please use `u\"m\"` to instantiate SI unit only."))
38+
end
39+
end
40+
end
41+
42+

src/systems/validation.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
module UnitfulUnitCheck
2+
3+
using .ModelingToolkit, Symbolics, SciMLBase
4+
using .ModelingToolkit: ValidationError
5+
const MT = ModelingToolkit
6+
17
Base.:*(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y
28
Base.:/(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x / y
39

4-
struct ValidationError <: Exception
5-
message::String
6-
end
7-
810
"""
911
Throw exception on invalid unit types, otherwise return argument.
1012
"""
@@ -60,7 +62,11 @@ get_literal_unit(x) = screen_unit(getmetadata(x, VariableUnit, unitless))
6062
function get_unit(op, args) # Fallback
6163
result = op(1 .* get_unit.(args)...)
6264
try
63-
unit(result)
65+
if result isa DQ.AbstractQuantity
66+
oneunit(result)
67+
else
68+
unit(result)
69+
end
6470
catch
6571
throw(ValidationError("Unable to get unit for operation $op with arguments $args."))
6672
end
@@ -211,15 +217,15 @@ function _validate(conn::Connection; info::String = "")
211217
valid
212218
end
213219

214-
function validate(jump::Union{ModelingToolkit.VariableRateJump,
215-
ModelingToolkit.ConstantRateJump}, t::Symbolic;
220+
function validate(jump::Union{MT.VariableRateJump,
221+
MT.ConstantRateJump}, t::Symbolic;
216222
info::String = "")
217223
newinfo = replace(info, "eq." => "jump")
218224
_validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units
219225
validate(jump.affect!, info = newinfo)
220226
end
221227

222-
function validate(jump::ModelingToolkit.MassActionJump, t::Symbolic; info::String = "")
228+
function validate(jump::MT.MassActionJump, t::Symbolic; info::String = "")
223229
left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols
224230
net_symbols = [x[1] for x in jump.net_stoch]
225231
all_symbols = vcat(left_symbols, net_symbols)
@@ -235,18 +241,18 @@ function validate(jumps::ArrayPartition{<:Union{Any, Vector{<:JumpType}}}, t::Sy
235241
all([validate(jumps.x[idx], t, info = labels[idx]) for idx in 1:3])
236242
end
237243

238-
function validate(eq::ModelingToolkit.Equation; info::String = "")
244+
function validate(eq::MT.Equation; info::String = "")
239245
if typeof(eq.lhs) == Connection
240246
_validate(eq.rhs; info)
241247
else
242248
_validate([eq.lhs, eq.rhs], ["left", "right"]; info)
243249
end
244250
end
245-
function validate(eq::ModelingToolkit.Equation,
251+
function validate(eq::MT.Equation,
246252
term::Union{Symbolic, Unitful.Quantity, Num}; info::String = "")
247253
_validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info)
248254
end
249-
function validate(eq::ModelingToolkit.Equation, terms::Vector; info::String = "")
255+
function validate(eq::MT.Equation, terms::Vector; info::String = "")
250256
_validate(vcat([eq.lhs, eq.rhs], terms),
251257
vcat(["left", "right"], "noise #" .* string.(1:length(terms))); info)
252258
end
@@ -273,8 +279,10 @@ validate(term::Symbolics.SymbolicUtils.Symbolic) = safe_get_unit(term, "") !== n
273279
"""
274280
Throws error if units of equations are invalid.
275281
"""
276-
function check_units(eqs...)
282+
function MT.check_units(::Val{:Unitful}, eqs...)
277283
validate(eqs...) ||
278284
throw(ValidationError("Some equations had invalid units. See warnings for details."))
279285
end
280286
all_dimensionless(states) = all(x -> safe_get_unit(x, "") in (unitless, nothing), states)
287+
288+
end # module

test/runtests.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@ using SafeTestsets, Test
1313
@safetestset "Clock Test" include("clock.jl")
1414
@safetestset "DiscreteSystem Test" include("discretesystem.jl")
1515
@safetestset "ODESystem Test" include("odesystem.jl")
16-
@safetestset "Unitful Quantities Test" include("units.jl")
16+
@safetestset "Dynamic Quantities Test" begin
17+
using DynamicQuantities
18+
include("units.jl")
19+
end
20+
@safetestset "Unitful Quantities Test" begin
21+
using Unitful
22+
include("units.jl")
23+
end
1724
@safetestset "LabelledArrays Test" include("labelledarrays.jl")
1825
@safetestset "Mass Matrix Test" include("mass_matrix.jl")
1926
@safetestset "SteadyStateSystem Test" include("steadystatesystems.jl")

test/units.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, Unitful, OrdinaryDiffEq, JumpProcesses, IfElse
1+
using ModelingToolkit, OrdinaryDiffEq, JumpProcesses, IfElse
22
using Test
33
MT = ModelingToolkit
44
@parameters τ [unit = u"ms"] γ

0 commit comments

Comments
 (0)