Skip to content

Commit abaad91

Browse files
Refactor Unitful.jl usage to use package extensions
This commit moves all Unitful.jl-specific functionality into a ModelingToolkitUnitfulExt extension to reduce required dependencies and improve loading times. Major changes: - Move Unitful from dependencies to weakdeps in Project.toml - Create ModelingToolkitUnitfulExt extension with all Unitful-specific functionality - Remove UnitfulUnitCheck module from main codebase, moved to extension - Update unit checking functions to be extensible - Remove direct Unitful imports from main module - Add extensible error handling for dimension errors The extension provides backward compatibility by recreating the UnitfulUnitCheck module when Unitful is loaded. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 647c9f9 commit abaad91

File tree

6 files changed

+367
-309
lines changed

6 files changed

+367
-309
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
6464
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
6565
URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
6666
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
67-
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
6867

6968
[weakdeps]
7069
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
@@ -74,6 +73,7 @@ FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac"
7473
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
7574
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
7675
Pyomo = "0e8e1daf-01b5-4eba-a626-3897743a3816"
76+
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
7777

7878
[extensions]
7979
MTKBifurcationKitExt = "BifurcationKit"
@@ -83,6 +83,7 @@ MTKFMIExt = "FMI"
8383
MTKInfiniteOptExt = "InfiniteOpt"
8484
MTKLabelledArraysExt = "LabelledArrays"
8585
MTKPyomoDynamicOptExt = "Pyomo"
86+
ModelingToolkitUnitfulExt = "Unitful"
8687

8788
[compat]
8889
ADTypes = "1.14.0"
@@ -165,7 +166,6 @@ SymbolicUtils = "3.26.1"
165166
Symbolics = "6.40"
166167
URIs = "1"
167168
UnPack = "0.1, 1.0"
168-
Unitful = "1.1"
169169
julia = "1.9"
170170

171171
[extras]

ext/ModelingToolkitUnitfulExt.jl

Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
module ModelingToolkitUnitfulExt
2+
3+
__precompile__(false)
4+
5+
using ModelingToolkit
6+
using Unitful
7+
using Symbolics: Symbolic, value, issym, isadd, ismul, ispow, arguments, operation, iscall, getmetadata
8+
using SciMLBase
9+
using RecursiveArrayTools
10+
using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump
11+
12+
# Import necessary types and functions from ModelingToolkit
13+
import ModelingToolkit: ValidationError, Connection, instream, JumpType, VariableUnit,
14+
get_systems, Conditional, Comparison, Differential,
15+
Integral, Num, check_units
16+
17+
const MT = ModelingToolkit
18+
19+
# Method extension for Unitful unit detection
20+
# This adds a method for the specific case where we have a Unitful unit
21+
function MT.__get_scalar_unit_type(v)
22+
u = MT.__get_literal_unit(v)
23+
if u isa MT.DQ.AbstractQuantity
24+
return Val(:DynamicQuantities)
25+
elseif u isa Unitful.Unitlike
26+
return Val(:Unitful)
27+
end
28+
return nothing
29+
end
30+
31+
# Base operations for mixing Symbolic and Unitful
32+
Base.:*(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y
33+
Base.:/(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x / y
34+
35+
"""
36+
Throw exception on invalid unit types, otherwise return argument.
37+
"""
38+
function screen_unit(result)
39+
result isa Unitful.Unitlike ||
40+
throw(ValidationError("Unit must be a subtype of Unitful.Unitlike, not $(typeof(result))."))
41+
result isa Unitful.ScalarUnits ||
42+
throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead."))
43+
result == Unitful.u"°" &&
44+
throw(ValidationError("Degrees are not supported. Use radians instead."))
45+
result
46+
end
47+
48+
"""
49+
Test unit equivalence.
50+
51+
Example of implemented behavior:
52+
53+
```julia
54+
using ModelingToolkit, Unitful
55+
MT = ModelingToolkit
56+
@parameters γ P [unit = u"MW"] E [unit = u"kJ"] τ [unit = u"ms"]
57+
@test MT.equivalent(u"MW", u"kJ/ms") # Understands prefixes
58+
@test !MT.equivalent(u"m", u"cm") # Units must be same magnitude
59+
@test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E / τ)^γ)) # Handles symbolic exponents
60+
```
61+
"""
62+
equivalent(x, y) = isequal(1 * x, 1 * y)
63+
const unitless = Unitful.unit(1)
64+
65+
"""
66+
Find the unit of a symbolic item.
67+
"""
68+
get_unit(x::Real) = unitless
69+
get_unit(x::Unitful.Quantity) = screen_unit(Unitful.unit(x))
70+
get_unit(x::AbstractArray) = map(get_unit, x)
71+
get_unit(x::Num) = get_unit(value(x))
72+
function get_unit(x::Union{Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata})
73+
get_literal_unit(x)
74+
end
75+
get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x)
76+
get_unit(op::typeof(getindex), args) = get_unit(args[1])
77+
get_unit(x::SciMLBase.NullParameters) = unitless
78+
get_unit(op::typeof(instream), args) = get_unit(args[1])
79+
80+
get_literal_unit(x) = screen_unit(getmetadata(x, VariableUnit, unitless))
81+
82+
function get_unit(op, args) # Fallback
83+
result = op(1 .* get_unit.(args)...)
84+
try
85+
Unitful.unit(result)
86+
catch
87+
throw(ValidationError("Unable to get unit for operation $op with arguments $args."))
88+
end
89+
end
90+
91+
function get_unit(op::Integral, args)
92+
unit = 1
93+
if op.domain.variables isa Vector
94+
for u in op.domain.variables
95+
unit *= get_unit(u)
96+
end
97+
else
98+
unit *= get_unit(op.domain.variables)
99+
end
100+
return get_unit(args[1]) * unit
101+
end
102+
103+
function get_unit(op::Conditional, args)
104+
terms = get_unit.(args)
105+
terms[1] == unitless ||
106+
throw(ValidationError(", in $op, [$(terms[1])] is not dimensionless."))
107+
equivalent(terms[2], terms[3]) ||
108+
throw(ValidationError(", in $op, units [$(terms[2])] and [$(terms[3])] do not match."))
109+
return terms[2]
110+
end
111+
112+
function get_unit(op::typeof(Symbolics._mapreduce), args)
113+
if args[2] == +
114+
get_unit(args[3])
115+
else
116+
throw(ValidationError("Unsupported array operation $op"))
117+
end
118+
end
119+
120+
function get_unit(op::Comparison, args)
121+
terms = get_unit.(args)
122+
equivalent(terms[1], terms[2]) ||
123+
throw(ValidationError(", in comparison $op, units [$(terms[1])] and [$(terms[2])] do not match."))
124+
return unitless
125+
end
126+
127+
function get_unit(x::Symbolic)
128+
if issym(x)
129+
get_literal_unit(x)
130+
elseif isadd(x)
131+
terms = get_unit.(arguments(x))
132+
firstunit = terms[1]
133+
for other in terms[2:end]
134+
termlist = join(map(repr, terms), ", ")
135+
equivalent(other, firstunit) ||
136+
throw(ValidationError(", in sum $x, units [$termlist] do not match."))
137+
end
138+
return firstunit
139+
elseif ispow(x)
140+
pargs = arguments(x)
141+
base, expon = get_unit.(pargs)
142+
@assert expon isa Unitful.DimensionlessUnits
143+
if base == unitless
144+
unitless
145+
else
146+
pargs[2] isa Number ? base^pargs[2] : (1 * base)^pargs[2]
147+
end
148+
elseif iscall(x)
149+
op = operation(x)
150+
if issym(op) || (iscall(op) && iscall(operation(op))) # Dependent variables, not function calls
151+
return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i]
152+
elseif iscall(op) && !iscall(operation(op))
153+
gp = getmetadata(x, Symbolics.GetindexParent, nothing) # Like x[1](t)
154+
return screen_unit(getmetadata(gp, VariableUnit, unitless))
155+
end # Actual function calls:
156+
args = arguments(x)
157+
return get_unit(op, args)
158+
else # This function should only be reached by Terms, for which `iscall` is true
159+
throw(ArgumentError("Unsupported value $x."))
160+
end
161+
end
162+
163+
"""
164+
Get unit of term, returning nothing & showing warning instead of throwing errors.
165+
"""
166+
function safe_get_unit(term, info)
167+
side = nothing
168+
try
169+
side = get_unit(term)
170+
catch err
171+
if err isa Unitful.DimensionError
172+
@warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.")
173+
elseif err isa ValidationError
174+
@warn(info*err.message)
175+
elseif err isa MethodError
176+
@warn("$info: no method matching $(err.f) for arguments $(typeof.(err.args)).")
177+
else
178+
rethrow()
179+
end
180+
end
181+
side
182+
end
183+
184+
function _validate(terms::Vector, labels::Vector{String}; info::String = "")
185+
valid = true
186+
first_unit = nothing
187+
first_label = nothing
188+
for (term, label) in zip(terms, labels)
189+
equnit = safe_get_unit(term, info * label)
190+
if equnit === nothing
191+
valid = false
192+
elseif !isequal(term, 0)
193+
if first_unit === nothing
194+
first_unit = equnit
195+
first_label = label
196+
elseif !equivalent(first_unit, equnit)
197+
valid = false
198+
@warn("$info: units [$(first_unit)] for $(first_label) and [$(equnit)] for $(label) do not match.")
199+
end
200+
end
201+
end
202+
valid
203+
end
204+
205+
function _validate(conn::Connection; info::String = "")
206+
valid = true
207+
syss = get_systems(conn)
208+
sys = first(syss)
209+
unks = MT.unknowns(sys)
210+
for i in 2:length(syss)
211+
s = syss[i]
212+
_unks = MT.unknowns(s)
213+
if length(unks) != length(_unks)
214+
valid = false
215+
@warn("$info: connected systems $(MT.nameof(sys)) and $(MT.nameof(s)) have $(length(unks)) and $(length(_unks)) unknowns, cannot connect.")
216+
continue
217+
end
218+
for (i, x) in enumerate(unks)
219+
j = findfirst(isequal(x), _unks)
220+
if j == nothing
221+
valid = false
222+
@warn("$info: connected systems $(MT.nameof(sys)) and $(MT.nameof(s)) do not have the same unknowns.")
223+
else
224+
aunit = safe_get_unit(x, info * string(MT.nameof(sys)) * "#$i")
225+
bunit = safe_get_unit(_unks[j], info * string(MT.nameof(s)) * "#$j")
226+
if !equivalent(aunit, bunit)
227+
valid = false
228+
@warn("$info: connected system unknowns $x and $(_unks[j]) have mismatched units.")
229+
end
230+
end
231+
end
232+
end
233+
valid
234+
end
235+
236+
function validate(jump::Union{VariableRateJump, ConstantRateJump}, t::Symbolic; info::String = "")
237+
newinfo = replace(info, "eq." => "jump")
238+
_validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units
239+
validate(jump.affect!, info = newinfo)
240+
end
241+
242+
function validate(jump::MassActionJump, t::Symbolic; info::String = "")
243+
left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols
244+
net_symbols = [x[1] for x in jump.net_stoch]
245+
all_symbols = vcat(left_symbols, net_symbols)
246+
allgood = _validate(all_symbols, string.(all_symbols); info)
247+
n = sum(x -> x[2], jump.reactant_stoch, init = 0)
248+
base_unitful = all_symbols[1] #all same, get first
249+
allgood && _validate([jump.scaled_rates, 1 / (t * base_unitful^n)],
250+
["scaled_rates", "1/(t*reactants^$n))"]; info)
251+
end
252+
253+
function validate(jumps::Vector{JumpType}, t::Symbolic)
254+
labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"]
255+
majs = filter(x -> x isa MassActionJump, jumps)
256+
crjs = filter(x -> x isa ConstantRateJump, jumps)
257+
vrjs = filter(x -> x isa VariableRateJump, jumps)
258+
splitjumps = [majs, crjs, vrjs]
259+
all([validate(js, t; info) for (js, info) in zip(splitjumps, labels)])
260+
end
261+
262+
function validate(eq::MT.Equation; info::String = "")
263+
if typeof(eq.lhs) == Connection
264+
_validate(eq.rhs; info)
265+
else
266+
_validate([eq.lhs, eq.rhs], ["left", "right"]; info)
267+
end
268+
end
269+
270+
function validate(eq::MT.Equation, term::Union{Symbolic, Unitful.Quantity, Num}; info::String = "")
271+
_validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info)
272+
end
273+
274+
function validate(eq::MT.Equation, terms::Vector; info::String = "")
275+
_validate(vcat([eq.lhs, eq.rhs], terms),
276+
vcat(["left", "right"], "noise #" .* string.(1:length(terms))); info)
277+
end
278+
279+
"""
280+
Returns true iff units of equations are valid.
281+
"""
282+
function validate(eqs::Vector; info::String = "")
283+
all([validate(eqs[idx], info = info * " in eq. #$idx") for idx in 1:length(eqs)])
284+
end
285+
286+
function validate(eqs::Vector, noise::Vector; info::String = "")
287+
all([validate(eqs[idx], noise[idx], info = info * " in eq. #$idx")
288+
for idx in 1:length(eqs)])
289+
end
290+
291+
function validate(eqs::Vector, noise::Matrix; info::String = "")
292+
all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx")
293+
for idx in 1:length(eqs)])
294+
end
295+
296+
function validate(eqs::Vector, term::Symbolic; info::String = "")
297+
all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)])
298+
end
299+
300+
validate(term::Symbolic) = safe_get_unit(term, "") !== nothing
301+
302+
"""
303+
Throws error if units of equations are invalid.
304+
"""
305+
function check_units(::Val{:Unitful}, eqs...)
306+
validate(eqs...) ||
307+
throw(ValidationError("Some equations had invalid units. See warnings for details."))
308+
end
309+
310+
# Model parsing functions for Unitful
311+
function convert_units(varunits::Unitful.FreeUnits, value)
312+
Unitful.ustrip(varunits, value)
313+
end
314+
315+
convert_units(::Unitful.FreeUnits, value::MT.NoValue) = MT.NO_VALUE
316+
317+
function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T}
318+
Unitful.ustrip.(varunits, value)
319+
end
320+
321+
convert_units(::Unitful.FreeUnits, value::Num) = value
322+
323+
# Extend model parsing error handling to include Unitful.DimensionError
324+
MT._is_dimension_error(e::Unitful.DimensionError) = true
325+
326+
# Define Unitful time variables (moved from main module)
327+
const t_unitful = let
328+
MT.only(MT.@independent_variables t [unit = Unitful.u"s"])
329+
end
330+
const D_unitful = MT.Differential(t_unitful)
331+
332+
# Create a UnitfulUnitCheck module for backward compatibility
333+
module UnitfulUnitCheck
334+
using ..ModelingToolkitUnitfulExt
335+
# Re-export all functions from the extension for backward compatibility
336+
const equivalent = ModelingToolkitUnitfulExt.equivalent
337+
const unitless = ModelingToolkitUnitfulExt.unitless
338+
const get_unit = ModelingToolkitUnitfulExt.get_unit
339+
const get_literal_unit = ModelingToolkitUnitfulExt.get_literal_unit
340+
const safe_get_unit = ModelingToolkitUnitfulExt.safe_get_unit
341+
const validate = ModelingToolkitUnitfulExt.validate
342+
const screen_unit = ModelingToolkitUnitfulExt.screen_unit
343+
end
344+
345+
end # module

0 commit comments

Comments
 (0)