Skip to content

Refactor Unitful.jl usage to use package extensions #3869

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 40 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
abaad91
Refactor Unitful.jl usage to use package extensions
ChrisRackauckas Aug 4, 2025
17013eb
Refine Unitful extension: keep general unit functions in main package
ChrisRackauckas Aug 5, 2025
b24913f
Fix extension implementation: improve unit operations and test setup
ChrisRackauckas Aug 5, 2025
24f0183
Remove UnitfulUnitCheck internal interface - use ModelingToolkit func…
ChrisRackauckas Aug 5, 2025
f1dee4e
Fix major unit operation issues in extension
ChrisRackauckas Aug 5, 2025
dd8a2db
Revert "Fix major unit operation issues in extension"
ChrisRackauckas Aug 5, 2025
61a607a
Reorganize Unitful-related tests to Extensions group
ChrisRackauckas Aug 5, 2025
695e991
Export unit functions to fix DynamicQuantities test failures
ChrisRackauckas Aug 5, 2025
4729095
Fix DynamicQuantities test failures in extension refactor
ChrisRackauckas Aug 6, 2025
f2df1c5
Update src/ModelingToolkit.jl
ChrisRackauckas Aug 6, 2025
52679ac
Update src/systems/model_parsing.jl
ChrisRackauckas Aug 6, 2025
13985a2
Update src/ModelingToolkit.jl
ChrisRackauckas Aug 6, 2025
ccb192e
Update src/ModelingToolkit.jl
ChrisRackauckas Aug 6, 2025
3a0f123
Update src/systems/model_parsing.jl
ChrisRackauckas Aug 6, 2025
85ae964
Update src/systems/model_parsing.jl
ChrisRackauckas Aug 6, 2025
9e7247b
Update src/systems/unit_check.jl
ChrisRackauckas Aug 6, 2025
244679d
Update src/systems/validation.jl
ChrisRackauckas Aug 6, 2025
480f6a4
Update src/systems/unit_check.jl
ChrisRackauckas Aug 6, 2025
29ebb9f
Update src/systems/unit_check.jl
ChrisRackauckas Aug 6, 2025
19f63e8
Update src/systems/unit_check.jl
ChrisRackauckas Aug 6, 2025
1a34474
Delete unitful-extension-refactor.patch
ChrisRackauckas Aug 6, 2025
9d1a8ca
Update src/systems/model_parsing.jl
ChrisRackauckas Aug 6, 2025
8385a42
Fix missing UnitfulUnitCheck module by recreating it in extension
ChrisRackauckas Aug 6, 2025
01ddee0
Update ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 7, 2025
baf2939
Update ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 7, 2025
ae9312f
Update unit_check.jl
ChrisRackauckas Aug 7, 2025
3cad7b3
Update ext/ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 7, 2025
e826e2e
Update unit_check.jl
ChrisRackauckas Aug 7, 2025
4b3c91b
Update ext/ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 7, 2025
2b85ba3
Update ext/ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 7, 2025
0f602b9
Update ext/ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 7, 2025
7daa0e2
Update unit_check.jl
ChrisRackauckas Aug 7, 2025
00ad788
Update unit_check.jl
ChrisRackauckas Aug 8, 2025
ca31b12
Update unit_check.jl
ChrisRackauckas Aug 8, 2025
aa3b607
Update src/systems/unit_check.jl
ChrisRackauckas Aug 8, 2025
db66e2a
Update unit_check.jl
ChrisRackauckas Aug 8, 2025
85d119d
Update ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 8, 2025
95f88ef
Update ext/ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 8, 2025
f1d04bf
Update test/units.jl
ChrisRackauckas Aug 8, 2025
6bd166f
Update test/units.jl
ChrisRackauckas Aug 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[weakdeps]
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
Expand All @@ -74,6 +73,7 @@ FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac"
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
Pyomo = "0e8e1daf-01b5-4eba-a626-3897743a3816"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[extensions]
MTKBifurcationKitExt = "BifurcationKit"
Expand All @@ -83,6 +83,7 @@ MTKFMIExt = "FMI"
MTKInfiniteOptExt = "InfiniteOpt"
MTKLabelledArraysExt = "LabelledArrays"
MTKPyomoDynamicOptExt = "Pyomo"
ModelingToolkitUnitfulExt = "Unitful"

[compat]
ADTypes = "1.14.0"
Expand Down Expand Up @@ -165,7 +166,6 @@ SymbolicUtils = "3.26.1"
Symbolics = "6.40"
URIs = "1"
UnPack = "0.1, 1.0"
Unitful = "1.1"
julia = "1.9"

[extras]
Expand Down
345 changes: 345 additions & 0 deletions ext/ModelingToolkitUnitfulExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,345 @@
module ModelingToolkitUnitfulExt

__precompile__(false)

using ModelingToolkit
using Unitful
using Symbolics: Symbolic, value, issym, isadd, ismul, ispow, arguments, operation, iscall, getmetadata
using SciMLBase
using RecursiveArrayTools
using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump

# Import necessary types and functions from ModelingToolkit
import ModelingToolkit: ValidationError, Connection, instream, JumpType, VariableUnit,
get_systems, Conditional, Comparison, Differential,
Integral, Num, check_units

const MT = ModelingToolkit

# Method extension for Unitful unit detection
# This adds a method for the specific case where we have a Unitful unit
function MT.__get_scalar_unit_type(v)
u = MT.__get_literal_unit(v)
if u isa MT.DQ.AbstractQuantity
return Val(:DynamicQuantities)
elseif u isa Unitful.Unitlike
return Val(:Unitful)
end
return nothing
end

# Base operations for mixing Symbolic and Unitful
Base.:*(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y
Base.:/(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x / y

"""
Throw exception on invalid unit types, otherwise return argument.
"""
function screen_unit(result)
result isa Unitful.Unitlike ||
throw(ValidationError("Unit must be a subtype of Unitful.Unitlike, not $(typeof(result))."))
result isa Unitful.ScalarUnits ||
throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead."))
result == Unitful.u"°" &&
throw(ValidationError("Degrees are not supported. Use radians instead."))
result
end

"""
Test unit equivalence.

Example of implemented behavior:

```julia
using ModelingToolkit, Unitful
MT = ModelingToolkit
@parameters γ P [unit = u"MW"] E [unit = u"kJ"] τ [unit = u"ms"]
@test MT.equivalent(u"MW", u"kJ/ms") # Understands prefixes
@test !MT.equivalent(u"m", u"cm") # Units must be same magnitude
@test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E / τ)^γ)) # Handles symbolic exponents
```
"""
equivalent(x, y) = isequal(1 * x, 1 * y)
const unitless = Unitful.unit(1)

"""
Find the unit of a symbolic item.
"""
get_unit(x::Real) = unitless
get_unit(x::Unitful.Quantity) = screen_unit(Unitful.unit(x))
get_unit(x::AbstractArray) = map(get_unit, x)
get_unit(x::Num) = get_unit(value(x))
function get_unit(x::Union{Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata})
get_literal_unit(x)
end
get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x)
get_unit(op::typeof(getindex), args) = get_unit(args[1])
get_unit(x::SciMLBase.NullParameters) = unitless
get_unit(op::typeof(instream), args) = get_unit(args[1])

get_literal_unit(x) = screen_unit(getmetadata(x, VariableUnit, unitless))

function get_unit(op, args) # Fallback
result = op(1 .* get_unit.(args)...)
try
Unitful.unit(result)
catch
throw(ValidationError("Unable to get unit for operation $op with arguments $args."))
end
end

function get_unit(op::Integral, args)
unit = 1
if op.domain.variables isa Vector
for u in op.domain.variables
unit *= get_unit(u)
end
else
unit *= get_unit(op.domain.variables)
end
return get_unit(args[1]) * unit
end

function get_unit(op::Conditional, args)
terms = get_unit.(args)
terms[1] == unitless ||
throw(ValidationError(", in $op, [$(terms[1])] is not dimensionless."))
equivalent(terms[2], terms[3]) ||
throw(ValidationError(", in $op, units [$(terms[2])] and [$(terms[3])] do not match."))
return terms[2]
end

function get_unit(op::typeof(Symbolics._mapreduce), args)
if args[2] == +
get_unit(args[3])
else
throw(ValidationError("Unsupported array operation $op"))
end
end

function get_unit(op::Comparison, args)
terms = get_unit.(args)
equivalent(terms[1], terms[2]) ||
throw(ValidationError(", in comparison $op, units [$(terms[1])] and [$(terms[2])] do not match."))
return unitless
end

function get_unit(x::Symbolic)
if issym(x)
get_literal_unit(x)
elseif isadd(x)
terms = get_unit.(arguments(x))
firstunit = terms[1]
for other in terms[2:end]
termlist = join(map(repr, terms), ", ")
equivalent(other, firstunit) ||
throw(ValidationError(", in sum $x, units [$termlist] do not match."))
end
return firstunit
elseif ispow(x)
pargs = arguments(x)
base, expon = get_unit.(pargs)
@assert expon isa Unitful.DimensionlessUnits
if base == unitless
unitless
else
pargs[2] isa Number ? base^pargs[2] : (1 * base)^pargs[2]
end
elseif iscall(x)
op = operation(x)
if issym(op) || (iscall(op) && iscall(operation(op))) # Dependent variables, not function calls
return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i]
elseif iscall(op) && !iscall(operation(op))
gp = getmetadata(x, Symbolics.GetindexParent, nothing) # Like x[1](t)
return screen_unit(getmetadata(gp, VariableUnit, unitless))
end # Actual function calls:
args = arguments(x)
return get_unit(op, args)
else # This function should only be reached by Terms, for which `iscall` is true
throw(ArgumentError("Unsupported value $x."))
end
end

"""
Get unit of term, returning nothing & showing warning instead of throwing errors.
"""
function safe_get_unit(term, info)
side = nothing
try
side = get_unit(term)
catch err
if err isa Unitful.DimensionError
@warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.")
elseif err isa ValidationError
@warn(info*err.message)
elseif err isa MethodError
@warn("$info: no method matching $(err.f) for arguments $(typeof.(err.args)).")
else
rethrow()
end
end
side
end

function _validate(terms::Vector, labels::Vector{String}; info::String = "")
valid = true
first_unit = nothing
first_label = nothing
for (term, label) in zip(terms, labels)
equnit = safe_get_unit(term, info * label)
if equnit === nothing
valid = false
elseif !isequal(term, 0)
if first_unit === nothing
first_unit = equnit
first_label = label
elseif !equivalent(first_unit, equnit)
valid = false
@warn("$info: units [$(first_unit)] for $(first_label) and [$(equnit)] for $(label) do not match.")
end
end
end
valid
end

function _validate(conn::Connection; info::String = "")
valid = true
syss = get_systems(conn)
sys = first(syss)
unks = MT.unknowns(sys)
for i in 2:length(syss)
s = syss[i]
_unks = MT.unknowns(s)
if length(unks) != length(_unks)
valid = false
@warn("$info: connected systems $(MT.nameof(sys)) and $(MT.nameof(s)) have $(length(unks)) and $(length(_unks)) unknowns, cannot connect.")
continue
end
for (i, x) in enumerate(unks)
j = findfirst(isequal(x), _unks)
if j == nothing
valid = false
@warn("$info: connected systems $(MT.nameof(sys)) and $(MT.nameof(s)) do not have the same unknowns.")
else
aunit = safe_get_unit(x, info * string(MT.nameof(sys)) * "#$i")
bunit = safe_get_unit(_unks[j], info * string(MT.nameof(s)) * "#$j")
if !equivalent(aunit, bunit)
valid = false
@warn("$info: connected system unknowns $x and $(_unks[j]) have mismatched units.")
end
end
end
end
valid
end

function validate(jump::Union{VariableRateJump, ConstantRateJump}, t::Symbolic; info::String = "")
newinfo = replace(info, "eq." => "jump")
_validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units
validate(jump.affect!, info = newinfo)
end

function validate(jump::MassActionJump, t::Symbolic; info::String = "")
left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols
net_symbols = [x[1] for x in jump.net_stoch]
all_symbols = vcat(left_symbols, net_symbols)
allgood = _validate(all_symbols, string.(all_symbols); info)
n = sum(x -> x[2], jump.reactant_stoch, init = 0)
base_unitful = all_symbols[1] #all same, get first
allgood && _validate([jump.scaled_rates, 1 / (t * base_unitful^n)],
["scaled_rates", "1/(t*reactants^$n))"]; info)
end

function validate(jumps::Vector{JumpType}, t::Symbolic)
labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"]
majs = filter(x -> x isa MassActionJump, jumps)
crjs = filter(x -> x isa ConstantRateJump, jumps)
vrjs = filter(x -> x isa VariableRateJump, jumps)
splitjumps = [majs, crjs, vrjs]
all([validate(js, t; info) for (js, info) in zip(splitjumps, labels)])
end

function validate(eq::MT.Equation; info::String = "")
if typeof(eq.lhs) == Connection
_validate(eq.rhs; info)
else
_validate([eq.lhs, eq.rhs], ["left", "right"]; info)
end
end

function validate(eq::MT.Equation, term::Union{Symbolic, Unitful.Quantity, Num}; info::String = "")
_validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info)
end

function validate(eq::MT.Equation, terms::Vector; info::String = "")
_validate(vcat([eq.lhs, eq.rhs], terms),
vcat(["left", "right"], "noise #" .* string.(1:length(terms))); info)
end

"""
Returns true iff units of equations are valid.
"""
function validate(eqs::Vector; info::String = "")
all([validate(eqs[idx], info = info * " in eq. #$idx") for idx in 1:length(eqs)])
end

function validate(eqs::Vector, noise::Vector; info::String = "")
all([validate(eqs[idx], noise[idx], info = info * " in eq. #$idx")
for idx in 1:length(eqs)])
end

function validate(eqs::Vector, noise::Matrix; info::String = "")
all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx")
for idx in 1:length(eqs)])
end

function validate(eqs::Vector, term::Symbolic; info::String = "")
all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)])
end

validate(term::Symbolic) = safe_get_unit(term, "") !== nothing

"""
Throws error if units of equations are invalid.
"""
function check_units(::Val{:Unitful}, eqs...)
validate(eqs...) ||
throw(ValidationError("Some equations had invalid units. See warnings for details."))
end

# Model parsing functions for Unitful
function convert_units(varunits::Unitful.FreeUnits, value)
Unitful.ustrip(varunits, value)
end

convert_units(::Unitful.FreeUnits, value::MT.NoValue) = MT.NO_VALUE

function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T}
Unitful.ustrip.(varunits, value)
end

convert_units(::Unitful.FreeUnits, value::Num) = value

# Extend model parsing error handling to include Unitful.DimensionError
MT._is_dimension_error(e::Unitful.DimensionError) = true

# Define Unitful time variables (moved from main module)
const t_unitful = let
MT.only(MT.@independent_variables t [unit = Unitful.u"s"])
end
const D_unitful = MT.Differential(t_unitful)

# Create a UnitfulUnitCheck module for backward compatibility
module UnitfulUnitCheck
using ..ModelingToolkitUnitfulExt
# Re-export all functions from the extension for backward compatibility
const equivalent = ModelingToolkitUnitfulExt.equivalent
const unitless = ModelingToolkitUnitfulExt.unitless
const get_unit = ModelingToolkitUnitfulExt.get_unit
const get_literal_unit = ModelingToolkitUnitfulExt.get_literal_unit
const safe_get_unit = ModelingToolkitUnitfulExt.safe_get_unit
const validate = ModelingToolkitUnitfulExt.validate
const screen_unit = ModelingToolkitUnitfulExt.screen_unit
end

end # module
Loading
Loading