1
1
module ModelingToolkitUnitfulExt
2
2
3
- __precompile__ (false )
4
-
5
3
using ModelingToolkit
6
4
using Unitful
7
- using Symbolics: Symbolic, value, issym, isadd, ismul, ispow, arguments, operation, iscall, getmetadata
5
+ using Symbolics: Symbolic, value
8
6
using SciMLBase
9
- using RecursiveArrayTools
10
- using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump
11
7
12
8
# Import necessary types and functions from ModelingToolkit
13
- import ModelingToolkit: ValidationError, Connection, instream, JumpType, VariableUnit,
14
- get_systems, Conditional, Comparison, Differential,
15
- Integral, Num, check_units
9
+ import ModelingToolkit: ValidationError, _get_unittype, get_unit, screen_unit,
10
+ equivalent, _is_dimension_error, convert_units, check_units
16
11
17
12
const MT = ModelingToolkit
18
13
19
- # Method extension for Unitful unit detection
20
- # This adds a method for the specific case where we have a Unitful unit
21
- function MT. __get_scalar_unit_type (v)
22
- u = MT. __get_literal_unit (v)
23
- if u isa MT. DQ. AbstractQuantity
24
- return Val (:DynamicQuantities )
25
- elseif u isa Unitful. Unitlike
26
- return Val (:Unitful )
27
- end
28
- return nothing
14
+ # Add Unitful-specific unit type detection
15
+ function MT. _get_unittype (u:: Unitful.Unitlike )
16
+ return Val (:Unitful )
29
17
end
30
18
31
19
# Base operations for mixing Symbolic and Unitful
32
- Base.:* (x:: Union{Num, Symbolic} , y:: Unitful.AbstractQuantity ) = x * y
33
- Base.:/ (x:: Union{Num, Symbolic} , y:: Unitful.AbstractQuantity ) = x / y
20
+ Base.:* (x:: Union{MT.Num, Symbolic} , y:: Unitful.AbstractQuantity ) = x * y
21
+ Base.:/ (x:: Union{MT.Num, Symbolic} , y:: Unitful.AbstractQuantity ) = x / y
22
+
23
+ # Unitful-specific get_unit method
24
+ function MT. get_unit (x:: Unitful.Quantity )
25
+ return screen_unit (Unitful. unit (x))
26
+ end
34
27
35
- """
36
- Throw exception on invalid unit types, otherwise return argument.
37
- """
38
- function screen_unit (result)
39
- result isa Unitful. Unitlike ||
40
- throw (ValidationError (" Unit must be a subtype of Unitful.Unitlike, not $(typeof (result)) ." ))
28
+ # Unitful-specific screen_unit method
29
+ function MT. screen_unit (result:: Unitful.Unitlike )
41
30
result isa Unitful. ScalarUnits ||
42
31
throw (ValidationError (" Non-scalar units such as $result are not supported. Use a scalar unit instead." ))
43
32
result == Unitful. u " °" &&
44
33
throw (ValidationError (" Degrees are not supported. Use radians instead." ))
45
- result
46
- end
47
-
48
- """
49
- Test unit equivalence.
50
-
51
- Example of implemented behavior:
52
-
53
- ```julia
54
- using ModelingToolkit, Unitful
55
- MT = ModelingToolkit
56
- @parameters γ P [unit = u"MW"] E [unit = u"kJ"] τ [unit = u"ms"]
57
- @test MT.equivalent(u"MW", u"kJ/ms") # Understands prefixes
58
- @test !MT.equivalent(u"m", u"cm") # Units must be same magnitude
59
- @test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E / τ)^γ)) # Handles symbolic exponents
60
- ```
61
- """
62
- equivalent (x, y) = isequal (1 * x, 1 * y)
63
- const unitless = Unitful. unit (1 )
64
-
65
- """
66
- Find the unit of a symbolic item.
67
- """
68
- get_unit (x:: Real ) = unitless
69
- get_unit (x:: Unitful.Quantity ) = screen_unit (Unitful. unit (x))
70
- get_unit (x:: AbstractArray ) = map (get_unit, x)
71
- get_unit (x:: Num ) = get_unit (value (x))
72
- function get_unit (x:: Union{Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata} )
73
- get_literal_unit (x)
74
- end
75
- get_unit (op:: Differential , args) = get_unit (args[1 ]) / get_unit (op. x)
76
- get_unit (op:: typeof (getindex), args) = get_unit (args[1 ])
77
- get_unit (x:: SciMLBase.NullParameters ) = unitless
78
- get_unit (op:: typeof (instream), args) = get_unit (args[1 ])
79
-
80
- get_literal_unit (x) = screen_unit (getmetadata (x, VariableUnit, unitless))
81
-
82
- function get_unit (op, args) # Fallback
83
- result = op (1 .* get_unit .(args)... )
84
- try
85
- Unitful. unit (result)
86
- catch
87
- throw (ValidationError (" Unable to get unit for operation $op with arguments $args ." ))
88
- end
89
- end
90
-
91
- function get_unit (op:: Integral , args)
92
- unit = 1
93
- if op. domain. variables isa Vector
94
- for u in op. domain. variables
95
- unit *= get_unit (u)
96
- end
97
- else
98
- unit *= get_unit (op. domain. variables)
99
- end
100
- return get_unit (args[1 ]) * unit
34
+ return result
101
35
end
102
36
103
- function get_unit (op:: Conditional , args)
104
- terms = get_unit .(args)
105
- terms[1 ] == unitless ||
106
- throw (ValidationError (" , in $op , [$(terms[1 ]) ] is not dimensionless." ))
107
- equivalent (terms[2 ], terms[3 ]) ||
108
- throw (ValidationError (" , in $op , units [$(terms[2 ]) ] and [$(terms[3 ]) ] do not match." ))
109
- return terms[2 ]
37
+ # Unitful-specific equivalence check
38
+ function MT. equivalent (x:: Unitful.Unitlike , y:: Unitful.Unitlike )
39
+ return isequal (1 * x, 1 * y)
110
40
end
111
41
112
- function get_unit (op:: typeof (Symbolics. _mapreduce), args)
113
- if args[2 ] == +
114
- get_unit (args[3 ])
115
- else
116
- throw (ValidationError (" Unsupported array operation $op " ))
117
- end
118
- end
119
-
120
- function get_unit (op:: Comparison , args)
121
- terms = get_unit .(args)
122
- equivalent (terms[1 ], terms[2 ]) ||
123
- throw (ValidationError (" , in comparison $op , units [$(terms[1 ]) ] and [$(terms[2 ]) ] do not match." ))
124
- return unitless
125
- end
42
+ # Mixed equivalence checks
43
+ MT. equivalent (x:: Unitful.Unitlike , y) = isequal (1 * x, y)
44
+ MT. equivalent (x, y:: Unitful.Unitlike ) = isequal (x, 1 * y)
126
45
127
- function get_unit (x:: Symbolic )
128
- if issym (x)
129
- get_literal_unit (x)
130
- elseif isadd (x)
131
- terms = get_unit .(arguments (x))
132
- firstunit = terms[1 ]
133
- for other in terms[2 : end ]
134
- termlist = join (map (repr, terms), " , " )
135
- equivalent (other, firstunit) ||
136
- throw (ValidationError (" , in sum $x , units [$termlist ] do not match." ))
137
- end
138
- return firstunit
139
- elseif ispow (x)
140
- pargs = arguments (x)
141
- base, expon = get_unit .(pargs)
142
- @assert expon isa Unitful. DimensionlessUnits
143
- if base == unitless
144
- unitless
145
- else
146
- pargs[2 ] isa Number ? base^ pargs[2 ] : (1 * base)^ pargs[2 ]
147
- end
148
- elseif iscall (x)
149
- op = operation (x)
150
- if issym (op) || (iscall (op) && iscall (operation (op))) # Dependent variables, not function calls
151
- return screen_unit (getmetadata (x, VariableUnit, unitless)) # Like x(t) or x[i]
152
- elseif iscall (op) && ! iscall (operation (op))
153
- gp = getmetadata (x, Symbolics. GetindexParent, nothing ) # Like x[1](t)
154
- return screen_unit (getmetadata (gp, VariableUnit, unitless))
155
- end # Actual function calls:
156
- args = arguments (x)
157
- return get_unit (op, args)
158
- else # This function should only be reached by Terms, for which `iscall` is true
159
- throw (ArgumentError (" Unsupported value $x ." ))
160
- end
161
- end
46
+ # The safe_get_unit function stays in the main package and already handles DQ.DimensionError
47
+ # We just need to make sure it can handle Unitful.DimensionError too
48
+ # This will be handled by the main function's MethodError catch
162
49
163
- """
164
- Get unit of term, returning nothing & showing warning instead of throwing errors.
165
- """
166
- function safe_get_unit (term, info)
167
- side = nothing
168
- try
169
- side = get_unit (term)
170
- catch err
171
- if err isa Unitful. DimensionError
172
- @warn (" $info : $(err. x) and $(err. y) are not dimensionally compatible." )
173
- elseif err isa ValidationError
174
- @warn (info* err. message)
175
- elseif err isa MethodError
176
- @warn (" $info : no method matching $(err. f) for arguments $(typeof .(err. args)) ." )
177
- else
178
- rethrow ()
179
- end
180
- end
181
- side
182
- end
183
-
184
- function _validate (terms:: Vector , labels:: Vector{String} ; info:: String = " " )
185
- valid = true
186
- first_unit = nothing
187
- first_label = nothing
188
- for (term, label) in zip (terms, labels)
189
- equnit = safe_get_unit (term, info * label)
190
- if equnit === nothing
191
- valid = false
192
- elseif ! isequal (term, 0 )
193
- if first_unit === nothing
194
- first_unit = equnit
195
- first_label = label
196
- elseif ! equivalent (first_unit, equnit)
197
- valid = false
198
- @warn (" $info : units [$(first_unit) ] for $(first_label) and [$(equnit) ] for $(label) do not match." )
199
- end
200
- end
201
- end
202
- valid
203
- end
204
-
205
- function _validate (conn:: Connection ; info:: String = " " )
206
- valid = true
207
- syss = get_systems (conn)
208
- sys = first (syss)
209
- unks = MT. unknowns (sys)
210
- for i in 2 : length (syss)
211
- s = syss[i]
212
- _unks = MT. unknowns (s)
213
- if length (unks) != length (_unks)
214
- valid = false
215
- @warn (" $info : connected systems $(MT. nameof (sys)) and $(MT. nameof (s)) have $(length (unks)) and $(length (_unks)) unknowns, cannot connect." )
216
- continue
217
- end
218
- for (i, x) in enumerate (unks)
219
- j = findfirst (isequal (x), _unks)
220
- if j == nothing
221
- valid = false
222
- @warn (" $info : connected systems $(MT. nameof (sys)) and $(MT. nameof (s)) do not have the same unknowns." )
223
- else
224
- aunit = safe_get_unit (x, info * string (MT. nameof (sys)) * " #$i " )
225
- bunit = safe_get_unit (_unks[j], info * string (MT. nameof (s)) * " #$j " )
226
- if ! equivalent (aunit, bunit)
227
- valid = false
228
- @warn (" $info : connected system unknowns $x and $(_unks[j]) have mismatched units." )
229
- end
230
- end
231
- end
232
- end
233
- valid
234
- end
235
-
236
- function validate (jump:: Union{VariableRateJump, ConstantRateJump} , t:: Symbolic ; info:: String = " " )
237
- newinfo = replace (info, " eq." => " jump" )
238
- _validate ([jump. rate, 1 / t], [" rate" , " 1/t" ], info = newinfo) && # Assuming the rate is per time units
239
- validate (jump. affect!, info = newinfo)
240
- end
241
-
242
- function validate (jump:: MassActionJump , t:: Symbolic ; info:: String = " " )
243
- left_symbols = [x[1 ] for x in jump. reactant_stoch] # vector of pairs of symbol,int -> vector symbols
244
- net_symbols = [x[1 ] for x in jump. net_stoch]
245
- all_symbols = vcat (left_symbols, net_symbols)
246
- allgood = _validate (all_symbols, string .(all_symbols); info)
247
- n = sum (x -> x[2 ], jump. reactant_stoch, init = 0 )
248
- base_unitful = all_symbols[1 ] # all same, get first
249
- allgood && _validate ([jump. scaled_rates, 1 / (t * base_unitful^ n)],
250
- [" scaled_rates" , " 1/(t*reactants^$n ))" ]; info)
251
- end
252
-
253
- function validate (jumps:: Vector{JumpType} , t:: Symbolic )
254
- labels = [" in Mass Action Jumps," , " in Constant Rate Jumps," , " in Variable Rate Jumps," ]
255
- majs = filter (x -> x isa MassActionJump, jumps)
256
- crjs = filter (x -> x isa ConstantRateJump, jumps)
257
- vrjs = filter (x -> x isa VariableRateJump, jumps)
258
- splitjumps = [majs, crjs, vrjs]
259
- all ([validate (js, t; info) for (js, info) in zip (splitjumps, labels)])
260
- end
261
-
262
- function validate (eq:: MT.Equation ; info:: String = " " )
263
- if typeof (eq. lhs) == Connection
264
- _validate (eq. rhs; info)
265
- else
266
- _validate ([eq. lhs, eq. rhs], [" left" , " right" ]; info)
267
- end
268
- end
269
-
270
- function validate (eq:: MT.Equation , term:: Union{Symbolic, Unitful.Quantity, Num} ; info:: String = " " )
271
- _validate ([eq. lhs, eq. rhs, term], [" left" , " right" , " noise" ]; info)
272
- end
273
-
274
- function validate (eq:: MT.Equation , terms:: Vector ; info:: String = " " )
275
- _validate (vcat ([eq. lhs, eq. rhs], terms),
276
- vcat ([" left" , " right" ], " noise #" .* string .(1 : length (terms))); info)
277
- end
278
-
279
- """
280
- Returns true iff units of equations are valid.
281
- """
282
- function validate (eqs:: Vector ; info:: String = " " )
283
- all ([validate (eqs[idx], info = info * " in eq. #$idx " ) for idx in 1 : length (eqs)])
284
- end
285
-
286
- function validate (eqs:: Vector , noise:: Vector ; info:: String = " " )
287
- all ([validate (eqs[idx], noise[idx], info = info * " in eq. #$idx " )
288
- for idx in 1 : length (eqs)])
289
- end
290
-
291
- function validate (eqs:: Vector , noise:: Matrix ; info:: String = " " )
292
- all ([validate (eqs[idx], noise[idx, :], info = info * " in eq. #$idx " )
293
- for idx in 1 : length (eqs)])
294
- end
295
-
296
- function validate (eqs:: Vector , term:: Symbolic ; info:: String = " " )
297
- all ([validate (eqs[idx], term, info = info * " in eq. #$idx " ) for idx in 1 : length (eqs)])
298
- end
299
-
300
- validate (term:: Symbolic ) = safe_get_unit (term, " " ) != = nothing
301
-
302
- """
303
- Throws error if units of equations are invalid.
304
- """
305
- function check_units (:: Val{:Unitful} , eqs... )
306
- validate (eqs... ) ||
307
- throw (ValidationError (" Some equations had invalid units. See warnings for details." ))
308
- end
50
+ # Unitful-specific dimension error detection for model parsing
51
+ MT. _is_dimension_error (e:: Unitful.DimensionError ) = true
309
52
310
- # Model parsing functions for Unitful
311
- function convert_units (varunits:: Unitful.FreeUnits , value)
53
+ # Unitful-specific convert_units methods for model parsing
54
+ function MT . convert_units (varunits:: Unitful.FreeUnits , value)
312
55
Unitful. ustrip (varunits, value)
313
56
end
314
57
315
- convert_units (:: Unitful.FreeUnits , value:: MT.NoValue ) = MT. NO_VALUE
58
+ MT . convert_units (:: Unitful.FreeUnits , value:: MT.NoValue ) = MT. NO_VALUE
316
59
317
- function convert_units (varunits:: Unitful.FreeUnits , value:: AbstractArray{T} ) where {T}
60
+ function MT . convert_units (varunits:: Unitful.FreeUnits , value:: AbstractArray{T} ) where {T}
318
61
Unitful. ustrip .(varunits, value)
319
62
end
320
63
321
- convert_units (:: Unitful.FreeUnits , value:: Num ) = value
64
+ MT . convert_units (:: Unitful.FreeUnits , value:: MT. Num ) = value
322
65
323
- # Extend model parsing error handling to include Unitful.DimensionError
324
- MT. _is_dimension_error (e:: Unitful.DimensionError ) = true
66
+ # Unitful-specific check_units method
67
+ function MT. check_units (:: Val{:Unitful} , eqs... )
68
+ # Use the main package's validate function
69
+ MT. validate (eqs... ) ||
70
+ throw (ValidationError (" Some equations had invalid units. See warnings for details." ))
71
+ end
325
72
326
73
# Define Unitful time variables (moved from main module)
327
74
const t_unitful = let
328
75
MT. only (MT. @independent_variables t [unit = Unitful. u " s" ])
329
76
end
330
77
const D_unitful = MT. Differential (t_unitful)
331
78
332
- # Create a UnitfulUnitCheck module for backward compatibility
333
- module UnitfulUnitCheck
334
- using .. ModelingToolkitUnitfulExt
335
- # Re-export all functions from the extension for backward compatibility
336
- const equivalent = ModelingToolkitUnitfulExt . equivalent
337
- const unitless = ModelingToolkitUnitfulExt . unitless
338
- const get_unit = ModelingToolkitUnitfulExt . get_unit
339
- const get_literal_unit = ModelingToolkitUnitfulExt . get_literal_unit
340
- const safe_get_unit = ModelingToolkitUnitfulExt . safe_get_unit
341
- const validate = ModelingToolkitUnitfulExt . validate
342
- const screen_unit = ModelingToolkitUnitfulExt . screen_unit
343
- end
79
+ # For backward compatibility - provide UnitfulUnitCheck module interface
80
+ # Extensions can access all the main package functions through MT
81
+ const UnitfulUnitCheck = (
82
+ equivalent = MT . equivalent,
83
+ unitless = Unitful . unit ( 1 ),
84
+ get_unit = MT . get_unit,
85
+ get_literal_unit = MT . get_literal_unit,
86
+ safe_get_unit = MT . safe_get_unit,
87
+ validate = MT . validate,
88
+ screen_unit = MT . screen_unit,
89
+ ValidationError = ValidationError
90
+ )
344
91
345
92
end # module
0 commit comments