Skip to content

Commit 17013eb

Browse files
Refine Unitful extension: keep general unit functions in main package
This commit fixes the extension approach to be more surgical: - Keep all general unit functions (get_unit, validate, etc.) in main package - Only move Unitful-specific dispatches to extension - Use proper multiple dispatch with _get_unittype stub function - Fix method overwriting issues and compilation problems Key changes: - src/systems/unit_check.jl: Add _get_unittype extensible function - ext/ModelingToolkitUnitfulExt.jl: Only Unitful-specific methods - Remove method overwriting of _is_dimension_error - Fix __init__ function issues The extension now properly extends the main package without replacing core functionality, following Julia extension best practices. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent abaad91 commit 17013eb

File tree

4 files changed

+61
-319
lines changed

4 files changed

+61
-319
lines changed

ext/ModelingToolkitUnitfulExt.jl

Lines changed: 50 additions & 303 deletions
Original file line numberDiff line numberDiff line change
@@ -1,345 +1,92 @@
11
module ModelingToolkitUnitfulExt
22

3-
__precompile__(false)
4-
53
using ModelingToolkit
64
using Unitful
7-
using Symbolics: Symbolic, value, issym, isadd, ismul, ispow, arguments, operation, iscall, getmetadata
5+
using Symbolics: Symbolic, value
86
using SciMLBase
9-
using RecursiveArrayTools
10-
using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump
117

128
# Import necessary types and functions from ModelingToolkit
13-
import ModelingToolkit: ValidationError, Connection, instream, JumpType, VariableUnit,
14-
get_systems, Conditional, Comparison, Differential,
15-
Integral, Num, check_units
9+
import ModelingToolkit: ValidationError, _get_unittype, get_unit, screen_unit,
10+
equivalent, _is_dimension_error, convert_units, check_units
1611

1712
const MT = ModelingToolkit
1813

19-
# Method extension for Unitful unit detection
20-
# This adds a method for the specific case where we have a Unitful unit
21-
function MT.__get_scalar_unit_type(v)
22-
u = MT.__get_literal_unit(v)
23-
if u isa MT.DQ.AbstractQuantity
24-
return Val(:DynamicQuantities)
25-
elseif u isa Unitful.Unitlike
26-
return Val(:Unitful)
27-
end
28-
return nothing
14+
# Add Unitful-specific unit type detection
15+
function MT._get_unittype(u::Unitful.Unitlike)
16+
return Val(:Unitful)
2917
end
3018

3119
# Base operations for mixing Symbolic and Unitful
32-
Base.:*(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y
33-
Base.:/(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x / y
20+
Base.:*(x::Union{MT.Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y
21+
Base.:/(x::Union{MT.Num, Symbolic}, y::Unitful.AbstractQuantity) = x / y
22+
23+
# Unitful-specific get_unit method
24+
function MT.get_unit(x::Unitful.Quantity)
25+
return screen_unit(Unitful.unit(x))
26+
end
3427

35-
"""
36-
Throw exception on invalid unit types, otherwise return argument.
37-
"""
38-
function screen_unit(result)
39-
result isa Unitful.Unitlike ||
40-
throw(ValidationError("Unit must be a subtype of Unitful.Unitlike, not $(typeof(result))."))
28+
# Unitful-specific screen_unit method
29+
function MT.screen_unit(result::Unitful.Unitlike)
4130
result isa Unitful.ScalarUnits ||
4231
throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead."))
4332
result == Unitful.u"°" &&
4433
throw(ValidationError("Degrees are not supported. Use radians instead."))
45-
result
46-
end
47-
48-
"""
49-
Test unit equivalence.
50-
51-
Example of implemented behavior:
52-
53-
```julia
54-
using ModelingToolkit, Unitful
55-
MT = ModelingToolkit
56-
@parameters γ P [unit = u"MW"] E [unit = u"kJ"] τ [unit = u"ms"]
57-
@test MT.equivalent(u"MW", u"kJ/ms") # Understands prefixes
58-
@test !MT.equivalent(u"m", u"cm") # Units must be same magnitude
59-
@test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E / τ)^γ)) # Handles symbolic exponents
60-
```
61-
"""
62-
equivalent(x, y) = isequal(1 * x, 1 * y)
63-
const unitless = Unitful.unit(1)
64-
65-
"""
66-
Find the unit of a symbolic item.
67-
"""
68-
get_unit(x::Real) = unitless
69-
get_unit(x::Unitful.Quantity) = screen_unit(Unitful.unit(x))
70-
get_unit(x::AbstractArray) = map(get_unit, x)
71-
get_unit(x::Num) = get_unit(value(x))
72-
function get_unit(x::Union{Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata})
73-
get_literal_unit(x)
74-
end
75-
get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x)
76-
get_unit(op::typeof(getindex), args) = get_unit(args[1])
77-
get_unit(x::SciMLBase.NullParameters) = unitless
78-
get_unit(op::typeof(instream), args) = get_unit(args[1])
79-
80-
get_literal_unit(x) = screen_unit(getmetadata(x, VariableUnit, unitless))
81-
82-
function get_unit(op, args) # Fallback
83-
result = op(1 .* get_unit.(args)...)
84-
try
85-
Unitful.unit(result)
86-
catch
87-
throw(ValidationError("Unable to get unit for operation $op with arguments $args."))
88-
end
89-
end
90-
91-
function get_unit(op::Integral, args)
92-
unit = 1
93-
if op.domain.variables isa Vector
94-
for u in op.domain.variables
95-
unit *= get_unit(u)
96-
end
97-
else
98-
unit *= get_unit(op.domain.variables)
99-
end
100-
return get_unit(args[1]) * unit
34+
return result
10135
end
10236

103-
function get_unit(op::Conditional, args)
104-
terms = get_unit.(args)
105-
terms[1] == unitless ||
106-
throw(ValidationError(", in $op, [$(terms[1])] is not dimensionless."))
107-
equivalent(terms[2], terms[3]) ||
108-
throw(ValidationError(", in $op, units [$(terms[2])] and [$(terms[3])] do not match."))
109-
return terms[2]
37+
# Unitful-specific equivalence check
38+
function MT.equivalent(x::Unitful.Unitlike, y::Unitful.Unitlike)
39+
return isequal(1 * x, 1 * y)
11040
end
11141

112-
function get_unit(op::typeof(Symbolics._mapreduce), args)
113-
if args[2] == +
114-
get_unit(args[3])
115-
else
116-
throw(ValidationError("Unsupported array operation $op"))
117-
end
118-
end
119-
120-
function get_unit(op::Comparison, args)
121-
terms = get_unit.(args)
122-
equivalent(terms[1], terms[2]) ||
123-
throw(ValidationError(", in comparison $op, units [$(terms[1])] and [$(terms[2])] do not match."))
124-
return unitless
125-
end
42+
# Mixed equivalence checks
43+
MT.equivalent(x::Unitful.Unitlike, y) = isequal(1 * x, y)
44+
MT.equivalent(x, y::Unitful.Unitlike) = isequal(x, 1 * y)
12645

127-
function get_unit(x::Symbolic)
128-
if issym(x)
129-
get_literal_unit(x)
130-
elseif isadd(x)
131-
terms = get_unit.(arguments(x))
132-
firstunit = terms[1]
133-
for other in terms[2:end]
134-
termlist = join(map(repr, terms), ", ")
135-
equivalent(other, firstunit) ||
136-
throw(ValidationError(", in sum $x, units [$termlist] do not match."))
137-
end
138-
return firstunit
139-
elseif ispow(x)
140-
pargs = arguments(x)
141-
base, expon = get_unit.(pargs)
142-
@assert expon isa Unitful.DimensionlessUnits
143-
if base == unitless
144-
unitless
145-
else
146-
pargs[2] isa Number ? base^pargs[2] : (1 * base)^pargs[2]
147-
end
148-
elseif iscall(x)
149-
op = operation(x)
150-
if issym(op) || (iscall(op) && iscall(operation(op))) # Dependent variables, not function calls
151-
return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i]
152-
elseif iscall(op) && !iscall(operation(op))
153-
gp = getmetadata(x, Symbolics.GetindexParent, nothing) # Like x[1](t)
154-
return screen_unit(getmetadata(gp, VariableUnit, unitless))
155-
end # Actual function calls:
156-
args = arguments(x)
157-
return get_unit(op, args)
158-
else # This function should only be reached by Terms, for which `iscall` is true
159-
throw(ArgumentError("Unsupported value $x."))
160-
end
161-
end
46+
# The safe_get_unit function stays in the main package and already handles DQ.DimensionError
47+
# We just need to make sure it can handle Unitful.DimensionError too
48+
# This will be handled by the main function's MethodError catch
16249

163-
"""
164-
Get unit of term, returning nothing & showing warning instead of throwing errors.
165-
"""
166-
function safe_get_unit(term, info)
167-
side = nothing
168-
try
169-
side = get_unit(term)
170-
catch err
171-
if err isa Unitful.DimensionError
172-
@warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.")
173-
elseif err isa ValidationError
174-
@warn(info*err.message)
175-
elseif err isa MethodError
176-
@warn("$info: no method matching $(err.f) for arguments $(typeof.(err.args)).")
177-
else
178-
rethrow()
179-
end
180-
end
181-
side
182-
end
183-
184-
function _validate(terms::Vector, labels::Vector{String}; info::String = "")
185-
valid = true
186-
first_unit = nothing
187-
first_label = nothing
188-
for (term, label) in zip(terms, labels)
189-
equnit = safe_get_unit(term, info * label)
190-
if equnit === nothing
191-
valid = false
192-
elseif !isequal(term, 0)
193-
if first_unit === nothing
194-
first_unit = equnit
195-
first_label = label
196-
elseif !equivalent(first_unit, equnit)
197-
valid = false
198-
@warn("$info: units [$(first_unit)] for $(first_label) and [$(equnit)] for $(label) do not match.")
199-
end
200-
end
201-
end
202-
valid
203-
end
204-
205-
function _validate(conn::Connection; info::String = "")
206-
valid = true
207-
syss = get_systems(conn)
208-
sys = first(syss)
209-
unks = MT.unknowns(sys)
210-
for i in 2:length(syss)
211-
s = syss[i]
212-
_unks = MT.unknowns(s)
213-
if length(unks) != length(_unks)
214-
valid = false
215-
@warn("$info: connected systems $(MT.nameof(sys)) and $(MT.nameof(s)) have $(length(unks)) and $(length(_unks)) unknowns, cannot connect.")
216-
continue
217-
end
218-
for (i, x) in enumerate(unks)
219-
j = findfirst(isequal(x), _unks)
220-
if j == nothing
221-
valid = false
222-
@warn("$info: connected systems $(MT.nameof(sys)) and $(MT.nameof(s)) do not have the same unknowns.")
223-
else
224-
aunit = safe_get_unit(x, info * string(MT.nameof(sys)) * "#$i")
225-
bunit = safe_get_unit(_unks[j], info * string(MT.nameof(s)) * "#$j")
226-
if !equivalent(aunit, bunit)
227-
valid = false
228-
@warn("$info: connected system unknowns $x and $(_unks[j]) have mismatched units.")
229-
end
230-
end
231-
end
232-
end
233-
valid
234-
end
235-
236-
function validate(jump::Union{VariableRateJump, ConstantRateJump}, t::Symbolic; info::String = "")
237-
newinfo = replace(info, "eq." => "jump")
238-
_validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units
239-
validate(jump.affect!, info = newinfo)
240-
end
241-
242-
function validate(jump::MassActionJump, t::Symbolic; info::String = "")
243-
left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols
244-
net_symbols = [x[1] for x in jump.net_stoch]
245-
all_symbols = vcat(left_symbols, net_symbols)
246-
allgood = _validate(all_symbols, string.(all_symbols); info)
247-
n = sum(x -> x[2], jump.reactant_stoch, init = 0)
248-
base_unitful = all_symbols[1] #all same, get first
249-
allgood && _validate([jump.scaled_rates, 1 / (t * base_unitful^n)],
250-
["scaled_rates", "1/(t*reactants^$n))"]; info)
251-
end
252-
253-
function validate(jumps::Vector{JumpType}, t::Symbolic)
254-
labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"]
255-
majs = filter(x -> x isa MassActionJump, jumps)
256-
crjs = filter(x -> x isa ConstantRateJump, jumps)
257-
vrjs = filter(x -> x isa VariableRateJump, jumps)
258-
splitjumps = [majs, crjs, vrjs]
259-
all([validate(js, t; info) for (js, info) in zip(splitjumps, labels)])
260-
end
261-
262-
function validate(eq::MT.Equation; info::String = "")
263-
if typeof(eq.lhs) == Connection
264-
_validate(eq.rhs; info)
265-
else
266-
_validate([eq.lhs, eq.rhs], ["left", "right"]; info)
267-
end
268-
end
269-
270-
function validate(eq::MT.Equation, term::Union{Symbolic, Unitful.Quantity, Num}; info::String = "")
271-
_validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info)
272-
end
273-
274-
function validate(eq::MT.Equation, terms::Vector; info::String = "")
275-
_validate(vcat([eq.lhs, eq.rhs], terms),
276-
vcat(["left", "right"], "noise #" .* string.(1:length(terms))); info)
277-
end
278-
279-
"""
280-
Returns true iff units of equations are valid.
281-
"""
282-
function validate(eqs::Vector; info::String = "")
283-
all([validate(eqs[idx], info = info * " in eq. #$idx") for idx in 1:length(eqs)])
284-
end
285-
286-
function validate(eqs::Vector, noise::Vector; info::String = "")
287-
all([validate(eqs[idx], noise[idx], info = info * " in eq. #$idx")
288-
for idx in 1:length(eqs)])
289-
end
290-
291-
function validate(eqs::Vector, noise::Matrix; info::String = "")
292-
all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx")
293-
for idx in 1:length(eqs)])
294-
end
295-
296-
function validate(eqs::Vector, term::Symbolic; info::String = "")
297-
all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)])
298-
end
299-
300-
validate(term::Symbolic) = safe_get_unit(term, "") !== nothing
301-
302-
"""
303-
Throws error if units of equations are invalid.
304-
"""
305-
function check_units(::Val{:Unitful}, eqs...)
306-
validate(eqs...) ||
307-
throw(ValidationError("Some equations had invalid units. See warnings for details."))
308-
end
50+
# Unitful-specific dimension error detection for model parsing
51+
MT._is_dimension_error(e::Unitful.DimensionError) = true
30952

310-
# Model parsing functions for Unitful
311-
function convert_units(varunits::Unitful.FreeUnits, value)
53+
# Unitful-specific convert_units methods for model parsing
54+
function MT.convert_units(varunits::Unitful.FreeUnits, value)
31255
Unitful.ustrip(varunits, value)
31356
end
31457

315-
convert_units(::Unitful.FreeUnits, value::MT.NoValue) = MT.NO_VALUE
58+
MT.convert_units(::Unitful.FreeUnits, value::MT.NoValue) = MT.NO_VALUE
31659

317-
function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T}
60+
function MT.convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T}
31861
Unitful.ustrip.(varunits, value)
31962
end
32063

321-
convert_units(::Unitful.FreeUnits, value::Num) = value
64+
MT.convert_units(::Unitful.FreeUnits, value::MT.Num) = value
32265

323-
# Extend model parsing error handling to include Unitful.DimensionError
324-
MT._is_dimension_error(e::Unitful.DimensionError) = true
66+
# Unitful-specific check_units method
67+
function MT.check_units(::Val{:Unitful}, eqs...)
68+
# Use the main package's validate function
69+
MT.validate(eqs...) ||
70+
throw(ValidationError("Some equations had invalid units. See warnings for details."))
71+
end
32572

32673
# Define Unitful time variables (moved from main module)
32774
const t_unitful = let
32875
MT.only(MT.@independent_variables t [unit = Unitful.u"s"])
32976
end
33077
const D_unitful = MT.Differential(t_unitful)
33178

332-
# Create a UnitfulUnitCheck module for backward compatibility
333-
module UnitfulUnitCheck
334-
using ..ModelingToolkitUnitfulExt
335-
# Re-export all functions from the extension for backward compatibility
336-
const equivalent = ModelingToolkitUnitfulExt.equivalent
337-
const unitless = ModelingToolkitUnitfulExt.unitless
338-
const get_unit = ModelingToolkitUnitfulExt.get_unit
339-
const get_literal_unit = ModelingToolkitUnitfulExt.get_literal_unit
340-
const safe_get_unit = ModelingToolkitUnitfulExt.safe_get_unit
341-
const validate = ModelingToolkitUnitfulExt.validate
342-
const screen_unit = ModelingToolkitUnitfulExt.screen_unit
343-
end
79+
# For backward compatibility - provide UnitfulUnitCheck module interface
80+
# Extensions can access all the main package functions through MT
81+
const UnitfulUnitCheck = (
82+
equivalent = MT.equivalent,
83+
unitless = Unitful.unit(1),
84+
get_unit = MT.get_unit,
85+
get_literal_unit = MT.get_literal_unit,
86+
safe_get_unit = MT.safe_get_unit,
87+
validate = MT.validate,
88+
screen_unit = MT.screen_unit,
89+
ValidationError = ValidationError
90+
)
34491

34592
end # module

0 commit comments

Comments
 (0)