1
+ module ModelingToolkitUnitfulExt
2
+
3
+ __precompile__ (false )
4
+
5
+ using ModelingToolkit
6
+ using Unitful
7
+ using Symbolics: Symbolic, value, issym, isadd, ismul, ispow, arguments, operation, iscall, getmetadata
8
+ using SciMLBase
9
+ using RecursiveArrayTools
10
+ using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump
11
+
12
+ # Import necessary types and functions from ModelingToolkit
13
+ import ModelingToolkit: ValidationError, Connection, instream, JumpType, VariableUnit,
14
+ get_systems, Conditional, Comparison, Differential,
15
+ Integral, Num, check_units
16
+
17
+ const MT = ModelingToolkit
18
+
19
+ # Method extension for Unitful unit detection
20
+ # This adds a method for the specific case where we have a Unitful unit
21
+ function MT. __get_scalar_unit_type (v)
22
+ u = MT. __get_literal_unit (v)
23
+ if u isa MT. DQ. AbstractQuantity
24
+ return Val (:DynamicQuantities )
25
+ elseif u isa Unitful. Unitlike
26
+ return Val (:Unitful )
27
+ end
28
+ return nothing
29
+ end
30
+
31
+ # Base operations for mixing Symbolic and Unitful
32
+ Base.:* (x:: Union{Num, Symbolic} , y:: Unitful.AbstractQuantity ) = x * y
33
+ Base.:/ (x:: Union{Num, Symbolic} , y:: Unitful.AbstractQuantity ) = x / y
34
+
35
+ """
36
+ Throw exception on invalid unit types, otherwise return argument.
37
+ """
38
+ function screen_unit (result)
39
+ result isa Unitful. Unitlike ||
40
+ throw (ValidationError (" Unit must be a subtype of Unitful.Unitlike, not $(typeof (result)) ." ))
41
+ result isa Unitful. ScalarUnits ||
42
+ throw (ValidationError (" Non-scalar units such as $result are not supported. Use a scalar unit instead." ))
43
+ result == Unitful. u " °" &&
44
+ throw (ValidationError (" Degrees are not supported. Use radians instead." ))
45
+ result
46
+ end
47
+
48
+ """
49
+ Test unit equivalence.
50
+
51
+ Example of implemented behavior:
52
+
53
+ ```julia
54
+ using ModelingToolkit, Unitful
55
+ MT = ModelingToolkit
56
+ @parameters γ P [unit = u"MW"] E [unit = u"kJ"] τ [unit = u"ms"]
57
+ @test MT.equivalent(u"MW", u"kJ/ms") # Understands prefixes
58
+ @test !MT.equivalent(u"m", u"cm") # Units must be same magnitude
59
+ @test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E / τ)^γ)) # Handles symbolic exponents
60
+ ```
61
+ """
62
+ equivalent (x, y) = isequal (1 * x, 1 * y)
63
+ const unitless = Unitful. unit (1 )
64
+
65
+ """
66
+ Find the unit of a symbolic item.
67
+ """
68
+ get_unit (x:: Real ) = unitless
69
+ get_unit (x:: Unitful.Quantity ) = screen_unit (Unitful. unit (x))
70
+ get_unit (x:: AbstractArray ) = map (get_unit, x)
71
+ get_unit (x:: Num ) = get_unit (value (x))
72
+ function get_unit (x:: Union{Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata} )
73
+ get_literal_unit (x)
74
+ end
75
+ get_unit (op:: Differential , args) = get_unit (args[1 ]) / get_unit (op. x)
76
+ get_unit (op:: typeof (getindex), args) = get_unit (args[1 ])
77
+ get_unit (x:: SciMLBase.NullParameters ) = unitless
78
+ get_unit (op:: typeof (instream), args) = get_unit (args[1 ])
79
+
80
+ get_literal_unit (x) = screen_unit (getmetadata (x, VariableUnit, unitless))
81
+
82
+ function get_unit (op, args) # Fallback
83
+ result = op (1 .* get_unit .(args)... )
84
+ try
85
+ Unitful. unit (result)
86
+ catch
87
+ throw (ValidationError (" Unable to get unit for operation $op with arguments $args ." ))
88
+ end
89
+ end
90
+
91
+ function get_unit (op:: Integral , args)
92
+ unit = 1
93
+ if op. domain. variables isa Vector
94
+ for u in op. domain. variables
95
+ unit *= get_unit (u)
96
+ end
97
+ else
98
+ unit *= get_unit (op. domain. variables)
99
+ end
100
+ return get_unit (args[1 ]) * unit
101
+ end
102
+
103
+ function get_unit (op:: Conditional , args)
104
+ terms = get_unit .(args)
105
+ terms[1 ] == unitless ||
106
+ throw (ValidationError (" , in $op , [$(terms[1 ]) ] is not dimensionless." ))
107
+ equivalent (terms[2 ], terms[3 ]) ||
108
+ throw (ValidationError (" , in $op , units [$(terms[2 ]) ] and [$(terms[3 ]) ] do not match." ))
109
+ return terms[2 ]
110
+ end
111
+
112
+ function get_unit (op:: typeof (Symbolics. _mapreduce), args)
113
+ if args[2 ] == +
114
+ get_unit (args[3 ])
115
+ else
116
+ throw (ValidationError (" Unsupported array operation $op " ))
117
+ end
118
+ end
119
+
120
+ function get_unit (op:: Comparison , args)
121
+ terms = get_unit .(args)
122
+ equivalent (terms[1 ], terms[2 ]) ||
123
+ throw (ValidationError (" , in comparison $op , units [$(terms[1 ]) ] and [$(terms[2 ]) ] do not match." ))
124
+ return unitless
125
+ end
126
+
127
+ function get_unit (x:: Symbolic )
128
+ if issym (x)
129
+ get_literal_unit (x)
130
+ elseif isadd (x)
131
+ terms = get_unit .(arguments (x))
132
+ firstunit = terms[1 ]
133
+ for other in terms[2 : end ]
134
+ termlist = join (map (repr, terms), " , " )
135
+ equivalent (other, firstunit) ||
136
+ throw (ValidationError (" , in sum $x , units [$termlist ] do not match." ))
137
+ end
138
+ return firstunit
139
+ elseif ispow (x)
140
+ pargs = arguments (x)
141
+ base, expon = get_unit .(pargs)
142
+ @assert expon isa Unitful. DimensionlessUnits
143
+ if base == unitless
144
+ unitless
145
+ else
146
+ pargs[2 ] isa Number ? base^ pargs[2 ] : (1 * base)^ pargs[2 ]
147
+ end
148
+ elseif iscall (x)
149
+ op = operation (x)
150
+ if issym (op) || (iscall (op) && iscall (operation (op))) # Dependent variables, not function calls
151
+ return screen_unit (getmetadata (x, VariableUnit, unitless)) # Like x(t) or x[i]
152
+ elseif iscall (op) && ! iscall (operation (op))
153
+ gp = getmetadata (x, Symbolics. GetindexParent, nothing ) # Like x[1](t)
154
+ return screen_unit (getmetadata (gp, VariableUnit, unitless))
155
+ end # Actual function calls:
156
+ args = arguments (x)
157
+ return get_unit (op, args)
158
+ else # This function should only be reached by Terms, for which `iscall` is true
159
+ throw (ArgumentError (" Unsupported value $x ." ))
160
+ end
161
+ end
162
+
163
+ """
164
+ Get unit of term, returning nothing & showing warning instead of throwing errors.
165
+ """
166
+ function safe_get_unit (term, info)
167
+ side = nothing
168
+ try
169
+ side = get_unit (term)
170
+ catch err
171
+ if err isa Unitful. DimensionError
172
+ @warn (" $info : $(err. x) and $(err. y) are not dimensionally compatible." )
173
+ elseif err isa ValidationError
174
+ @warn (info* err. message)
175
+ elseif err isa MethodError
176
+ @warn (" $info : no method matching $(err. f) for arguments $(typeof .(err. args)) ." )
177
+ else
178
+ rethrow ()
179
+ end
180
+ end
181
+ side
182
+ end
183
+
184
+ function _validate (terms:: Vector , labels:: Vector{String} ; info:: String = " " )
185
+ valid = true
186
+ first_unit = nothing
187
+ first_label = nothing
188
+ for (term, label) in zip (terms, labels)
189
+ equnit = safe_get_unit (term, info * label)
190
+ if equnit === nothing
191
+ valid = false
192
+ elseif ! isequal (term, 0 )
193
+ if first_unit === nothing
194
+ first_unit = equnit
195
+ first_label = label
196
+ elseif ! equivalent (first_unit, equnit)
197
+ valid = false
198
+ @warn (" $info : units [$(first_unit) ] for $(first_label) and [$(equnit) ] for $(label) do not match." )
199
+ end
200
+ end
201
+ end
202
+ valid
203
+ end
204
+
205
+ function _validate (conn:: Connection ; info:: String = " " )
206
+ valid = true
207
+ syss = get_systems (conn)
208
+ sys = first (syss)
209
+ unks = MT. unknowns (sys)
210
+ for i in 2 : length (syss)
211
+ s = syss[i]
212
+ _unks = MT. unknowns (s)
213
+ if length (unks) != length (_unks)
214
+ valid = false
215
+ @warn (" $info : connected systems $(MT. nameof (sys)) and $(MT. nameof (s)) have $(length (unks)) and $(length (_unks)) unknowns, cannot connect." )
216
+ continue
217
+ end
218
+ for (i, x) in enumerate (unks)
219
+ j = findfirst (isequal (x), _unks)
220
+ if j == nothing
221
+ valid = false
222
+ @warn (" $info : connected systems $(MT. nameof (sys)) and $(MT. nameof (s)) do not have the same unknowns." )
223
+ else
224
+ aunit = safe_get_unit (x, info * string (MT. nameof (sys)) * " #$i " )
225
+ bunit = safe_get_unit (_unks[j], info * string (MT. nameof (s)) * " #$j " )
226
+ if ! equivalent (aunit, bunit)
227
+ valid = false
228
+ @warn (" $info : connected system unknowns $x and $(_unks[j]) have mismatched units." )
229
+ end
230
+ end
231
+ end
232
+ end
233
+ valid
234
+ end
235
+
236
+ function validate (jump:: Union{VariableRateJump, ConstantRateJump} , t:: Symbolic ; info:: String = " " )
237
+ newinfo = replace (info, " eq." => " jump" )
238
+ _validate ([jump. rate, 1 / t], [" rate" , " 1/t" ], info = newinfo) && # Assuming the rate is per time units
239
+ validate (jump. affect!, info = newinfo)
240
+ end
241
+
242
+ function validate (jump:: MassActionJump , t:: Symbolic ; info:: String = " " )
243
+ left_symbols = [x[1 ] for x in jump. reactant_stoch] # vector of pairs of symbol,int -> vector symbols
244
+ net_symbols = [x[1 ] for x in jump. net_stoch]
245
+ all_symbols = vcat (left_symbols, net_symbols)
246
+ allgood = _validate (all_symbols, string .(all_symbols); info)
247
+ n = sum (x -> x[2 ], jump. reactant_stoch, init = 0 )
248
+ base_unitful = all_symbols[1 ] # all same, get first
249
+ allgood && _validate ([jump. scaled_rates, 1 / (t * base_unitful^ n)],
250
+ [" scaled_rates" , " 1/(t*reactants^$n ))" ]; info)
251
+ end
252
+
253
+ function validate (jumps:: Vector{JumpType} , t:: Symbolic )
254
+ labels = [" in Mass Action Jumps," , " in Constant Rate Jumps," , " in Variable Rate Jumps," ]
255
+ majs = filter (x -> x isa MassActionJump, jumps)
256
+ crjs = filter (x -> x isa ConstantRateJump, jumps)
257
+ vrjs = filter (x -> x isa VariableRateJump, jumps)
258
+ splitjumps = [majs, crjs, vrjs]
259
+ all ([validate (js, t; info) for (js, info) in zip (splitjumps, labels)])
260
+ end
261
+
262
+ function validate (eq:: MT.Equation ; info:: String = " " )
263
+ if typeof (eq. lhs) == Connection
264
+ _validate (eq. rhs; info)
265
+ else
266
+ _validate ([eq. lhs, eq. rhs], [" left" , " right" ]; info)
267
+ end
268
+ end
269
+
270
+ function validate (eq:: MT.Equation , term:: Union{Symbolic, Unitful.Quantity, Num} ; info:: String = " " )
271
+ _validate ([eq. lhs, eq. rhs, term], [" left" , " right" , " noise" ]; info)
272
+ end
273
+
274
+ function validate (eq:: MT.Equation , terms:: Vector ; info:: String = " " )
275
+ _validate (vcat ([eq. lhs, eq. rhs], terms),
276
+ vcat ([" left" , " right" ], " noise #" .* string .(1 : length (terms))); info)
277
+ end
278
+
279
+ """
280
+ Returns true iff units of equations are valid.
281
+ """
282
+ function validate (eqs:: Vector ; info:: String = " " )
283
+ all ([validate (eqs[idx], info = info * " in eq. #$idx " ) for idx in 1 : length (eqs)])
284
+ end
285
+
286
+ function validate (eqs:: Vector , noise:: Vector ; info:: String = " " )
287
+ all ([validate (eqs[idx], noise[idx], info = info * " in eq. #$idx " )
288
+ for idx in 1 : length (eqs)])
289
+ end
290
+
291
+ function validate (eqs:: Vector , noise:: Matrix ; info:: String = " " )
292
+ all ([validate (eqs[idx], noise[idx, :], info = info * " in eq. #$idx " )
293
+ for idx in 1 : length (eqs)])
294
+ end
295
+
296
+ function validate (eqs:: Vector , term:: Symbolic ; info:: String = " " )
297
+ all ([validate (eqs[idx], term, info = info * " in eq. #$idx " ) for idx in 1 : length (eqs)])
298
+ end
299
+
300
+ validate (term:: Symbolic ) = safe_get_unit (term, " " ) != = nothing
301
+
302
+ """
303
+ Throws error if units of equations are invalid.
304
+ """
305
+ function check_units (:: Val{:Unitful} , eqs... )
306
+ validate (eqs... ) ||
307
+ throw (ValidationError (" Some equations had invalid units. See warnings for details." ))
308
+ end
309
+
310
+ # Model parsing functions for Unitful
311
+ function convert_units (varunits:: Unitful.FreeUnits , value)
312
+ Unitful. ustrip (varunits, value)
313
+ end
314
+
315
+ convert_units (:: Unitful.FreeUnits , value:: MT.NoValue ) = MT. NO_VALUE
316
+
317
+ function convert_units (varunits:: Unitful.FreeUnits , value:: AbstractArray{T} ) where {T}
318
+ Unitful. ustrip .(varunits, value)
319
+ end
320
+
321
+ convert_units (:: Unitful.FreeUnits , value:: Num ) = value
322
+
323
+ # Extend model parsing error handling to include Unitful.DimensionError
324
+ MT. _is_dimension_error (e:: Unitful.DimensionError ) = true
325
+
326
+ # Define Unitful time variables (moved from main module)
327
+ const t_unitful = let
328
+ MT. only (MT. @independent_variables t [unit = Unitful. u " s" ])
329
+ end
330
+ const D_unitful = MT. Differential (t_unitful)
331
+
332
+ # Create a UnitfulUnitCheck module for backward compatibility
333
+ module UnitfulUnitCheck
334
+ using .. ModelingToolkitUnitfulExt
335
+ # Re-export all functions from the extension for backward compatibility
336
+ const equivalent = ModelingToolkitUnitfulExt. equivalent
337
+ const unitless = ModelingToolkitUnitfulExt. unitless
338
+ const get_unit = ModelingToolkitUnitfulExt. get_unit
339
+ const get_literal_unit = ModelingToolkitUnitfulExt. get_literal_unit
340
+ const safe_get_unit = ModelingToolkitUnitfulExt. safe_get_unit
341
+ const validate = ModelingToolkitUnitfulExt. validate
342
+ const screen_unit = ModelingToolkitUnitfulExt. screen_unit
343
+ end
344
+
345
+ end # module
0 commit comments