Skip to content

Refactor Unitful.jl usage to use package extensions #3869

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 40 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
abaad91
Refactor Unitful.jl usage to use package extensions
ChrisRackauckas Aug 4, 2025
17013eb
Refine Unitful extension: keep general unit functions in main package
ChrisRackauckas Aug 5, 2025
b24913f
Fix extension implementation: improve unit operations and test setup
ChrisRackauckas Aug 5, 2025
24f0183
Remove UnitfulUnitCheck internal interface - use ModelingToolkit func…
ChrisRackauckas Aug 5, 2025
f1dee4e
Fix major unit operation issues in extension
ChrisRackauckas Aug 5, 2025
dd8a2db
Revert "Fix major unit operation issues in extension"
ChrisRackauckas Aug 5, 2025
61a607a
Reorganize Unitful-related tests to Extensions group
ChrisRackauckas Aug 5, 2025
695e991
Export unit functions to fix DynamicQuantities test failures
ChrisRackauckas Aug 5, 2025
4729095
Fix DynamicQuantities test failures in extension refactor
ChrisRackauckas Aug 6, 2025
f2df1c5
Update src/ModelingToolkit.jl
ChrisRackauckas Aug 6, 2025
52679ac
Update src/systems/model_parsing.jl
ChrisRackauckas Aug 6, 2025
13985a2
Update src/ModelingToolkit.jl
ChrisRackauckas Aug 6, 2025
ccb192e
Update src/ModelingToolkit.jl
ChrisRackauckas Aug 6, 2025
3a0f123
Update src/systems/model_parsing.jl
ChrisRackauckas Aug 6, 2025
85ae964
Update src/systems/model_parsing.jl
ChrisRackauckas Aug 6, 2025
9e7247b
Update src/systems/unit_check.jl
ChrisRackauckas Aug 6, 2025
244679d
Update src/systems/validation.jl
ChrisRackauckas Aug 6, 2025
480f6a4
Update src/systems/unit_check.jl
ChrisRackauckas Aug 6, 2025
29ebb9f
Update src/systems/unit_check.jl
ChrisRackauckas Aug 6, 2025
19f63e8
Update src/systems/unit_check.jl
ChrisRackauckas Aug 6, 2025
1a34474
Delete unitful-extension-refactor.patch
ChrisRackauckas Aug 6, 2025
9d1a8ca
Update src/systems/model_parsing.jl
ChrisRackauckas Aug 6, 2025
8385a42
Fix missing UnitfulUnitCheck module by recreating it in extension
ChrisRackauckas Aug 6, 2025
01ddee0
Update ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 7, 2025
baf2939
Update ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 7, 2025
ae9312f
Update unit_check.jl
ChrisRackauckas Aug 7, 2025
3cad7b3
Update ext/ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 7, 2025
e826e2e
Update unit_check.jl
ChrisRackauckas Aug 7, 2025
4b3c91b
Update ext/ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 7, 2025
2b85ba3
Update ext/ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 7, 2025
0f602b9
Update ext/ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 7, 2025
7daa0e2
Update unit_check.jl
ChrisRackauckas Aug 7, 2025
00ad788
Update unit_check.jl
ChrisRackauckas Aug 8, 2025
ca31b12
Update unit_check.jl
ChrisRackauckas Aug 8, 2025
aa3b607
Update src/systems/unit_check.jl
ChrisRackauckas Aug 8, 2025
db66e2a
Update unit_check.jl
ChrisRackauckas Aug 8, 2025
85d119d
Update ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 8, 2025
95f88ef
Update ext/ModelingToolkitUnitfulExt.jl
ChrisRackauckas Aug 8, 2025
f1d04bf
Update test/units.jl
ChrisRackauckas Aug 8, 2025
6bd166f
Update test/units.jl
ChrisRackauckas Aug 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[weakdeps]
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
Expand All @@ -74,6 +73,7 @@ FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac"
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
Pyomo = "0e8e1daf-01b5-4eba-a626-3897743a3816"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[extensions]
MTKBifurcationKitExt = "BifurcationKit"
Expand All @@ -83,6 +83,7 @@ MTKFMIExt = "FMI"
MTKInfiniteOptExt = "InfiniteOpt"
MTKLabelledArraysExt = "LabelledArrays"
MTKPyomoDynamicOptExt = "Pyomo"
ModelingToolkitUnitfulExt = "Unitful"

[compat]
ADTypes = "1.14.0"
Expand Down Expand Up @@ -165,7 +166,6 @@ SymbolicUtils = "3.26.1"
Symbolics = "6.40"
URIs = "1"
UnPack = "0.1, 1.0"
Unitful = "1.1"
julia = "1.9"

[extras]
Expand Down Expand Up @@ -205,6 +205,7 @@ StochasticDelayDiffEq = "29a0d76e-afc8-11e9-03a4-eda52ae4b960"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase", "LinearSolve"]
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase", "LinearSolve", "Unitful"]
81 changes: 81 additions & 0 deletions ext/ModelingToolkitUnitfulExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
module ModelingToolkitUnitfulExt

using ModelingToolkit
using Unitful
using Symbolics: Symbolic, value
using SciMLBase

# Import necessary types and functions from ModelingToolkit
import ModelingToolkit: ValidationError, _get_unittype, get_unit, screen_unit,
equivalent, _is_dimension_error, convert_units, check_units

const MT = ModelingToolkit

# Add Unitful-specific unit type detection
function MT._get_unittype(u::Unitful.Unitlike)
return Val(:Unitful)
end

# Base operations for mixing Symbolic and Unitful
Base.:*(x::Union{MT.Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y
Base.:/(x::Union{MT.Num, Symbolic}, y::Unitful.AbstractQuantity) = x / y

# Unitful-specific get_unit method
function MT.get_unit(x::Unitful.Quantity)
return screen_unit(Unitful.unit(x))
end

# Unitful-specific screen_unit method
function MT.screen_unit(result::Unitful.Unitlike)
result isa Unitful.ScalarUnits ||
throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead."))
result == Unitful.u"°" &&
throw(ValidationError("Degrees are not supported. Use radians instead."))
return result
end

# Unitful-specific equivalence check
function MT.equivalent(x::Unitful.Unitlike, y::Unitful.Unitlike)
return isequal(1 * x, 1 * y)
end

# Mixed equivalence checks
MT.equivalent(x::Unitful.Unitlike, y) = isequal(1 * x, y)
MT.equivalent(x, y::Unitful.Unitlike) = isequal(x, 1 * y)

# The safe_get_unit function stays in the main package and already handles DQ.DimensionError
# We just need to make sure it can handle Unitful.DimensionError too
# This will be handled by the main function's MethodError catch

# Unitful-specific dimension error detection for model parsing
MT._is_dimension_error(e::Unitful.DimensionError) = true

# Unitful-specific convert_units methods for model parsing
function MT.convert_units(varunits::Unitful.FreeUnits, value)
Unitful.ustrip(varunits, value)
end

MT.convert_units(::Unitful.FreeUnits, value::MT.NoValue) = MT.NO_VALUE

function MT.convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T}
Unitful.ustrip.(varunits, value)
end

MT.convert_units(::Unitful.FreeUnits, value::MT.Num) = value

# Unitful-specific check_units method
function MT.check_units(::Val{:Unitful}, eqs...)
# Use the main package's validate function
MT.validate(eqs...) ||
throw(ValidationError("Some equations had invalid units. See warnings for details."))
end

# Define Unitful time variables (moved from main module)
const t_unitful = let
MT.only(MT.@independent_variables t [unit = Unitful.u"s"])
end
const D_unitful = MT.Differential(t_unitful)

# Extension loaded - all Unitful-specific functionality is now available

end # module
8 changes: 2 additions & 6 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ using InteractiveUtils
using JumpProcesses
using DataStructures
using Base.Threads
using Latexify, Unitful, ArrayInterface
using Latexify, ArrayInterface
using Setfield, ConstructionBase
import Libdl
using DocStringExtensions
Expand Down Expand Up @@ -93,7 +93,7 @@ export independent_variables, unknowns, observables, parameters, full_parameters
@reexport using UnPack
RuntimeGeneratedFunctions.init(@__MODULE__)

import DynamicQuantities, Unitful
import DynamicQuantities
const DQ = DynamicQuantities

import DifferentiationInterface as DI
Expand Down Expand Up @@ -232,15 +232,11 @@ include("deprecations.jl")
const t_nounits = let
only(@independent_variables t)
end
const t_unitful = let
only(@independent_variables t [unit = Unitful.u"s"])
end
const t = let
only(@independent_variables t [unit = DQ.u"s"])
end

const D_nounits = Differential(t_nounits)
const D_unitful = Differential(t_unitful)
const D = Differential(t)

export ODEFunction, convert_system_indepvar,
Expand Down
13 changes: 1 addition & 12 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -890,17 +890,7 @@ function convert_units(
DynamicQuantities.SymbolicUnits.as_quantity(varunits), value))
end

function convert_units(varunits::Unitful.FreeUnits, value)
Unitful.ustrip(varunits, value)
end

convert_units(::Unitful.FreeUnits, value::NoValue) = NO_VALUE

function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T}
Unitful.ustrip.(varunits, value)
end

convert_units(::Unitful.FreeUnits, value::Num) = value

convert_units(::DynamicQuantities.Quantity, value::Num) = value

Expand All @@ -919,8 +909,7 @@ function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types)
try
$setdefault($vv, $convert_units($unit, $name))
catch e
if isa(e, $(DynamicQuantities.DimensionError)) ||
isa(e, $(Unitful.DimensionError))
if _is_dimension_error($e)
error("Unable to convert units for \'" * string(:($$vv)) * "\'")
elseif isa(e, MethodError)
error("No or invalid units provided for \'" * string(:($$vv)) *
Expand Down
31 changes: 21 additions & 10 deletions src/systems/unit_check.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ function __get_literal_unit(x)
u = getmetadata(v, VariableUnit, nothing)
u isa DQ.AbstractQuantity ? screen_unit(u) : u
end
# Extensible unit type detection - extensions can add methods for specific unit types
_get_unittype(u) = nothing
_get_unittype(u::DQ.AbstractQuantity) = Val(:DynamicQuantities)

function __get_scalar_unit_type(v)
u = __get_literal_unit(v)
if u isa DQ.AbstractQuantity
return Val(:DynamicQuantities)
elseif u isa Unitful.Unitlike
return Val(:Unitful)
end
return nothing
return _get_unittype(u)
end
function __get_unit_type(vs′...)
for vs in vs′
Expand Down Expand Up @@ -79,9 +78,17 @@ get_unit(op::typeof(instream), args) = get_unit(args[1])
function get_unit(op, args) # Fallback
result = oneunit(op(get_unit.(args)...))
try
get_unit(result)
result = op(unit_args...)
# For operations that return a unit directly, return oneunit to get the unit structure
return oneunit(result)
catch
throw(ValidationError("Unable to get unit for operation $op with arguments $args."))
try
# Try with oneunit for numeric operations
result = oneunit(op(unit_args...))
return get_unit(result)
catch
throw(ValidationError("Unable to get unit for operation $op with arguments $args."))
end
end
end

Expand Down Expand Up @@ -167,6 +174,10 @@ function get_unit(x::Symbolic)
end
end

# Dimension error detection function - extensible for different unit systems
_is_dimension_error(e) = false # Default fallback
_is_dimension_error(e::DQ.DimensionError) = true

"""
Get unit of term, returning nothing & showing warning instead of throwing errors.
"""
Expand All @@ -175,8 +186,8 @@ function safe_get_unit(term, info)
try
side = get_unit(term)
catch err
if err isa DQ.DimensionError
@warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.")
if _is_dimension_error(err)
@warn("$info: dimension error occurred.")
elseif err isa ValidationError
@warn(info*err.message)
elseif err isa MethodError
Expand Down
Loading
Loading