1+ module ModelingToolkitUnitfulExt
2+
3+ __precompile__ (false )
4+
5+ using ModelingToolkit
6+ using Unitful
7+ using Symbolics: Symbolic, value, issym, isadd, ismul, ispow, arguments, operation, iscall, getmetadata
8+ using SciMLBase
9+ using RecursiveArrayTools
10+ using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump
11+
12+ # Import necessary types and functions from ModelingToolkit
13+ import ModelingToolkit: ValidationError, Connection, instream, JumpType, VariableUnit,
14+ get_systems, Conditional, Comparison, Differential,
15+ Integral, Num, check_units
16+
17+ const MT = ModelingToolkit
18+
19+ # Method extension for Unitful unit detection
20+ # This adds a method for the specific case where we have a Unitful unit
21+ function MT. __get_scalar_unit_type (v)
22+ u = MT. __get_literal_unit (v)
23+ if u isa MT. DQ. AbstractQuantity
24+ return Val (:DynamicQuantities )
25+ elseif u isa Unitful. Unitlike
26+ return Val (:Unitful )
27+ end
28+ return nothing
29+ end
30+
31+ # Base operations for mixing Symbolic and Unitful
32+ Base.:* (x:: Union{Num, Symbolic} , y:: Unitful.AbstractQuantity ) = x * y
33+ Base.:/ (x:: Union{Num, Symbolic} , y:: Unitful.AbstractQuantity ) = x / y
34+
35+ """
36+ Throw exception on invalid unit types, otherwise return argument.
37+ """
38+ function screen_unit (result)
39+ result isa Unitful. Unitlike ||
40+ throw (ValidationError (" Unit must be a subtype of Unitful.Unitlike, not $(typeof (result)) ." ))
41+ result isa Unitful. ScalarUnits ||
42+ throw (ValidationError (" Non-scalar units such as $result are not supported. Use a scalar unit instead." ))
43+ result == Unitful. u " °" &&
44+ throw (ValidationError (" Degrees are not supported. Use radians instead." ))
45+ result
46+ end
47+
48+ """
49+ Test unit equivalence.
50+
51+ Example of implemented behavior:
52+
53+ ```julia
54+ using ModelingToolkit, Unitful
55+ MT = ModelingToolkit
56+ @parameters γ P [unit = u"MW"] E [unit = u"kJ"] τ [unit = u"ms"]
57+ @test MT.equivalent(u"MW", u"kJ/ms") # Understands prefixes
58+ @test !MT.equivalent(u"m", u"cm") # Units must be same magnitude
59+ @test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E / τ)^γ)) # Handles symbolic exponents
60+ ```
61+ """
62+ equivalent (x, y) = isequal (1 * x, 1 * y)
63+ const unitless = Unitful. unit (1 )
64+
65+ """
66+ Find the unit of a symbolic item.
67+ """
68+ get_unit (x:: Real ) = unitless
69+ get_unit (x:: Unitful.Quantity ) = screen_unit (Unitful. unit (x))
70+ get_unit (x:: AbstractArray ) = map (get_unit, x)
71+ get_unit (x:: Num ) = get_unit (value (x))
72+ function get_unit (x:: Union{Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata} )
73+ get_literal_unit (x)
74+ end
75+ get_unit (op:: Differential , args) = get_unit (args[1 ]) / get_unit (op. x)
76+ get_unit (op:: typeof (getindex), args) = get_unit (args[1 ])
77+ get_unit (x:: SciMLBase.NullParameters ) = unitless
78+ get_unit (op:: typeof (instream), args) = get_unit (args[1 ])
79+
80+ get_literal_unit (x) = screen_unit (getmetadata (x, VariableUnit, unitless))
81+
82+ function get_unit (op, args) # Fallback
83+ result = op (1 .* get_unit .(args)... )
84+ try
85+ Unitful. unit (result)
86+ catch
87+ throw (ValidationError (" Unable to get unit for operation $op with arguments $args ." ))
88+ end
89+ end
90+
91+ function get_unit (op:: Integral , args)
92+ unit = 1
93+ if op. domain. variables isa Vector
94+ for u in op. domain. variables
95+ unit *= get_unit (u)
96+ end
97+ else
98+ unit *= get_unit (op. domain. variables)
99+ end
100+ return get_unit (args[1 ]) * unit
101+ end
102+
103+ function get_unit (op:: Conditional , args)
104+ terms = get_unit .(args)
105+ terms[1 ] == unitless ||
106+ throw (ValidationError (" , in $op , [$(terms[1 ]) ] is not dimensionless." ))
107+ equivalent (terms[2 ], terms[3 ]) ||
108+ throw (ValidationError (" , in $op , units [$(terms[2 ]) ] and [$(terms[3 ]) ] do not match." ))
109+ return terms[2 ]
110+ end
111+
112+ function get_unit (op:: typeof (Symbolics. _mapreduce), args)
113+ if args[2 ] == +
114+ get_unit (args[3 ])
115+ else
116+ throw (ValidationError (" Unsupported array operation $op " ))
117+ end
118+ end
119+
120+ function get_unit (op:: Comparison , args)
121+ terms = get_unit .(args)
122+ equivalent (terms[1 ], terms[2 ]) ||
123+ throw (ValidationError (" , in comparison $op , units [$(terms[1 ]) ] and [$(terms[2 ]) ] do not match." ))
124+ return unitless
125+ end
126+
127+ function get_unit (x:: Symbolic )
128+ if issym (x)
129+ get_literal_unit (x)
130+ elseif isadd (x)
131+ terms = get_unit .(arguments (x))
132+ firstunit = terms[1 ]
133+ for other in terms[2 : end ]
134+ termlist = join (map (repr, terms), " , " )
135+ equivalent (other, firstunit) ||
136+ throw (ValidationError (" , in sum $x , units [$termlist ] do not match." ))
137+ end
138+ return firstunit
139+ elseif ispow (x)
140+ pargs = arguments (x)
141+ base, expon = get_unit .(pargs)
142+ @assert expon isa Unitful. DimensionlessUnits
143+ if base == unitless
144+ unitless
145+ else
146+ pargs[2 ] isa Number ? base^ pargs[2 ] : (1 * base)^ pargs[2 ]
147+ end
148+ elseif iscall (x)
149+ op = operation (x)
150+ if issym (op) || (iscall (op) && iscall (operation (op))) # Dependent variables, not function calls
151+ return screen_unit (getmetadata (x, VariableUnit, unitless)) # Like x(t) or x[i]
152+ elseif iscall (op) && ! iscall (operation (op))
153+ gp = getmetadata (x, Symbolics. GetindexParent, nothing ) # Like x[1](t)
154+ return screen_unit (getmetadata (gp, VariableUnit, unitless))
155+ end # Actual function calls:
156+ args = arguments (x)
157+ return get_unit (op, args)
158+ else # This function should only be reached by Terms, for which `iscall` is true
159+ throw (ArgumentError (" Unsupported value $x ." ))
160+ end
161+ end
162+
163+ """
164+ Get unit of term, returning nothing & showing warning instead of throwing errors.
165+ """
166+ function safe_get_unit (term, info)
167+ side = nothing
168+ try
169+ side = get_unit (term)
170+ catch err
171+ if err isa Unitful. DimensionError
172+ @warn (" $info : $(err. x) and $(err. y) are not dimensionally compatible." )
173+ elseif err isa ValidationError
174+ @warn (info* err. message)
175+ elseif err isa MethodError
176+ @warn (" $info : no method matching $(err. f) for arguments $(typeof .(err. args)) ." )
177+ else
178+ rethrow ()
179+ end
180+ end
181+ side
182+ end
183+
184+ function _validate (terms:: Vector , labels:: Vector{String} ; info:: String = " " )
185+ valid = true
186+ first_unit = nothing
187+ first_label = nothing
188+ for (term, label) in zip (terms, labels)
189+ equnit = safe_get_unit (term, info * label)
190+ if equnit === nothing
191+ valid = false
192+ elseif ! isequal (term, 0 )
193+ if first_unit === nothing
194+ first_unit = equnit
195+ first_label = label
196+ elseif ! equivalent (first_unit, equnit)
197+ valid = false
198+ @warn (" $info : units [$(first_unit) ] for $(first_label) and [$(equnit) ] for $(label) do not match." )
199+ end
200+ end
201+ end
202+ valid
203+ end
204+
205+ function _validate (conn:: Connection ; info:: String = " " )
206+ valid = true
207+ syss = get_systems (conn)
208+ sys = first (syss)
209+ unks = MT. unknowns (sys)
210+ for i in 2 : length (syss)
211+ s = syss[i]
212+ _unks = MT. unknowns (s)
213+ if length (unks) != length (_unks)
214+ valid = false
215+ @warn (" $info : connected systems $(MT. nameof (sys)) and $(MT. nameof (s)) have $(length (unks)) and $(length (_unks)) unknowns, cannot connect." )
216+ continue
217+ end
218+ for (i, x) in enumerate (unks)
219+ j = findfirst (isequal (x), _unks)
220+ if j == nothing
221+ valid = false
222+ @warn (" $info : connected systems $(MT. nameof (sys)) and $(MT. nameof (s)) do not have the same unknowns." )
223+ else
224+ aunit = safe_get_unit (x, info * string (MT. nameof (sys)) * " #$i " )
225+ bunit = safe_get_unit (_unks[j], info * string (MT. nameof (s)) * " #$j " )
226+ if ! equivalent (aunit, bunit)
227+ valid = false
228+ @warn (" $info : connected system unknowns $x and $(_unks[j]) have mismatched units." )
229+ end
230+ end
231+ end
232+ end
233+ valid
234+ end
235+
236+ function validate (jump:: Union{VariableRateJump, ConstantRateJump} , t:: Symbolic ; info:: String = " " )
237+ newinfo = replace (info, " eq." => " jump" )
238+ _validate ([jump. rate, 1 / t], [" rate" , " 1/t" ], info = newinfo) && # Assuming the rate is per time units
239+ validate (jump. affect!, info = newinfo)
240+ end
241+
242+ function validate (jump:: MassActionJump , t:: Symbolic ; info:: String = " " )
243+ left_symbols = [x[1 ] for x in jump. reactant_stoch] # vector of pairs of symbol,int -> vector symbols
244+ net_symbols = [x[1 ] for x in jump. net_stoch]
245+ all_symbols = vcat (left_symbols, net_symbols)
246+ allgood = _validate (all_symbols, string .(all_symbols); info)
247+ n = sum (x -> x[2 ], jump. reactant_stoch, init = 0 )
248+ base_unitful = all_symbols[1 ] # all same, get first
249+ allgood && _validate ([jump. scaled_rates, 1 / (t * base_unitful^ n)],
250+ [" scaled_rates" , " 1/(t*reactants^$n ))" ]; info)
251+ end
252+
253+ function validate (jumps:: Vector{JumpType} , t:: Symbolic )
254+ labels = [" in Mass Action Jumps," , " in Constant Rate Jumps," , " in Variable Rate Jumps," ]
255+ majs = filter (x -> x isa MassActionJump, jumps)
256+ crjs = filter (x -> x isa ConstantRateJump, jumps)
257+ vrjs = filter (x -> x isa VariableRateJump, jumps)
258+ splitjumps = [majs, crjs, vrjs]
259+ all ([validate (js, t; info) for (js, info) in zip (splitjumps, labels)])
260+ end
261+
262+ function validate (eq:: MT.Equation ; info:: String = " " )
263+ if typeof (eq. lhs) == Connection
264+ _validate (eq. rhs; info)
265+ else
266+ _validate ([eq. lhs, eq. rhs], [" left" , " right" ]; info)
267+ end
268+ end
269+
270+ function validate (eq:: MT.Equation , term:: Union{Symbolic, Unitful.Quantity, Num} ; info:: String = " " )
271+ _validate ([eq. lhs, eq. rhs, term], [" left" , " right" , " noise" ]; info)
272+ end
273+
274+ function validate (eq:: MT.Equation , terms:: Vector ; info:: String = " " )
275+ _validate (vcat ([eq. lhs, eq. rhs], terms),
276+ vcat ([" left" , " right" ], " noise #" .* string .(1 : length (terms))); info)
277+ end
278+
279+ """
280+ Returns true iff units of equations are valid.
281+ """
282+ function validate (eqs:: Vector ; info:: String = " " )
283+ all ([validate (eqs[idx], info = info * " in eq. #$idx " ) for idx in 1 : length (eqs)])
284+ end
285+
286+ function validate (eqs:: Vector , noise:: Vector ; info:: String = " " )
287+ all ([validate (eqs[idx], noise[idx], info = info * " in eq. #$idx " )
288+ for idx in 1 : length (eqs)])
289+ end
290+
291+ function validate (eqs:: Vector , noise:: Matrix ; info:: String = " " )
292+ all ([validate (eqs[idx], noise[idx, :], info = info * " in eq. #$idx " )
293+ for idx in 1 : length (eqs)])
294+ end
295+
296+ function validate (eqs:: Vector , term:: Symbolic ; info:: String = " " )
297+ all ([validate (eqs[idx], term, info = info * " in eq. #$idx " ) for idx in 1 : length (eqs)])
298+ end
299+
300+ validate (term:: Symbolic ) = safe_get_unit (term, " " ) != = nothing
301+
302+ """
303+ Throws error if units of equations are invalid.
304+ """
305+ function check_units (:: Val{:Unitful} , eqs... )
306+ validate (eqs... ) ||
307+ throw (ValidationError (" Some equations had invalid units. See warnings for details." ))
308+ end
309+
310+ # Model parsing functions for Unitful
311+ function convert_units (varunits:: Unitful.FreeUnits , value)
312+ Unitful. ustrip (varunits, value)
313+ end
314+
315+ convert_units (:: Unitful.FreeUnits , value:: MT.NoValue ) = MT. NO_VALUE
316+
317+ function convert_units (varunits:: Unitful.FreeUnits , value:: AbstractArray{T} ) where {T}
318+ Unitful. ustrip .(varunits, value)
319+ end
320+
321+ convert_units (:: Unitful.FreeUnits , value:: Num ) = value
322+
323+ # Extend model parsing error handling to include Unitful.DimensionError
324+ MT. _is_dimension_error (e:: Unitful.DimensionError ) = true
325+
326+ # Define Unitful time variables (moved from main module)
327+ const t_unitful = let
328+ MT. only (MT. @independent_variables t [unit = Unitful. u " s" ])
329+ end
330+ const D_unitful = MT. Differential (t_unitful)
331+
332+ # Create a UnitfulUnitCheck module for backward compatibility
333+ module UnitfulUnitCheck
334+ using .. ModelingToolkitUnitfulExt
335+ # Re-export all functions from the extension for backward compatibility
336+ const equivalent = ModelingToolkitUnitfulExt. equivalent
337+ const unitless = ModelingToolkitUnitfulExt. unitless
338+ const get_unit = ModelingToolkitUnitfulExt. get_unit
339+ const get_literal_unit = ModelingToolkitUnitfulExt. get_literal_unit
340+ const safe_get_unit = ModelingToolkitUnitfulExt. safe_get_unit
341+ const validate = ModelingToolkitUnitfulExt. validate
342+ const screen_unit = ModelingToolkitUnitfulExt. screen_unit
343+ end
344+
345+ end # module
0 commit comments