diff --git a/Project.toml b/Project.toml index 79b499748c..4fb64f1e4f 100644 --- a/Project.toml +++ b/Project.toml @@ -64,7 +64,6 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" -Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [weakdeps] BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665" @@ -74,6 +73,7 @@ FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac" InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" Pyomo = "0e8e1daf-01b5-4eba-a626-3897743a3816" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [extensions] MTKBifurcationKitExt = "BifurcationKit" @@ -83,6 +83,7 @@ MTKFMIExt = "FMI" MTKInfiniteOptExt = "InfiniteOpt" MTKLabelledArraysExt = "LabelledArrays" MTKPyomoDynamicOptExt = "Pyomo" +ModelingToolkitUnitfulExt = "Unitful" [compat] ADTypes = "1.14.0" @@ -165,7 +166,6 @@ SymbolicUtils = "3.26.1" Symbolics = "6.40" URIs = "1" UnPack = "0.1, 1.0" -Unitful = "1.1" julia = "1.9" [extras] @@ -205,6 +205,7 @@ StochasticDelayDiffEq = "29a0d76e-afc8-11e9-03a4-eda52ae4b960" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase", "LinearSolve"] +test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase", "LinearSolve", "Unitful"] diff --git a/ext/ModelingToolkitUnitfulExt.jl b/ext/ModelingToolkitUnitfulExt.jl new file mode 100644 index 0000000000..5984fd510e --- /dev/null +++ b/ext/ModelingToolkitUnitfulExt.jl @@ -0,0 +1,103 @@ +module ModelingToolkitUnitfulExt + +using ModelingToolkit, Symbolics, SciMLBase, Unitful, RecursiveArrayTools +using ModelingToolkit: ValidationError, Connection, instream, JumpType, VariableUnit, + get_systems, Conditional, Comparison, Integral, Differential +using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump +using Symbolics: Symbolic, value, issym, isadd, ismul, ispow, iscall, operation, arguments, getmetadata + +using Unitful +using SciMLBase + +# Import necessary types and functions from ModelingToolkit +import ModelingToolkit: ValidationError, _get_unittype, get_unit, screen_unit, + equivalent, _is_dimension_error, convert_units, check_units + +const MT = ModelingToolkit + +# Add Unitful-specific unit type detection +function MT._get_unittype(u::Unitful.Unitlike) + return Val(:Unitful) +end + +MT._oneunit(x::Unitful.FreeUnits) = 1 * x + +# Base operations for mixing Symbolic and Unitful +Base.:*(x::Union{MT.Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y +Base.:/(x::Union{MT.Num, Symbolic}, y::Unitful.AbstractQuantity) = x / y + +# Unitful-specific get_unit method +function MT.get_unit(x::Unitful.Quantity) + return screen_unit(Unitful.unit(x)) +end + +# Unitful-specific screen_unit method +function MT.screen_unit(result::Unitful.Unitlike) + result isa Unitful.ScalarUnits || + throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead.")) + result == Unitful.u"°" && + throw(ValidationError("Degrees are not supported. Use radians instead.")) + return result +end + +# Unitful-specific equivalence check +function MT.equivalent(x::Unitful.Unitlike, y::Unitful.Unitlike) + return isequal(1 * x, 1 * y) +end + +# Mixed equivalence checks +MT.equivalent(x::Unitful.Unitlike, y) = isequal(1 * x, y) +MT.equivalent(x, y::Unitful.Unitlike) = isequal(x, 1 * y) + +# The safe_get_unit function stays in the main package and already handles DQ.DimensionError +# We just need to make sure it can handle Unitful.DimensionError too +# This will be handled by the main function's MethodError catch + +# Unitful-specific dimension error detection for model parsing +MT._is_dimension_error(e::Unitful.DimensionError) = true + +# Unitful-specific convert_units methods for model parsing +function MT.convert_units(varunits::Unitful.FreeUnits, value) + Unitful.ustrip(varunits, value) +end + +MT.convert_units(::Unitful.FreeUnits, value::MT.NoValue) = MT.NO_VALUE + +function MT.convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T} + Unitful.ustrip.(varunits, value) +end + +MT.convert_units(::Unitful.FreeUnits, value::MT.Num) = value + +# Unitful-specific check_units method +function MT.check_units(::Val{:Unitful}, eqs...) + # Use the main package's validate function + MT.validate(eqs...) || + throw(ValidationError("Some equations had invalid units. See warnings for details.")) +end + +# Define Unitful time variables (moved from main module) +const t_unitful = let + MT.only(MT.@independent_variables t [unit = Unitful.u"s"]) +end +const D_unitful = MT.Differential(t_unitful) + + +""" +Throw exception on invalid unit types, otherwise return argument. +""" +function screen_unit(result::Unitful.AbstractQuantity) + result isa Unitful.Unitlike || + throw(ValidationError("Unit must be a subtype of Unitful.Unitlike, not $(typeof(result)).")) + result isa Unitful.ScalarUnits || + throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead.")) + result == Unitful.u"°" && + throw(ValidationError("Degrees are not supported. Use radians instead.")) + result +end + +const unitless = Unitful.unit(1) + + + +end # module UnitfulUnitCheck diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 2c259058b0..d56734c036 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -30,7 +30,7 @@ using InteractiveUtils using JumpProcesses using DataStructures using Base.Threads -using Latexify, Unitful, ArrayInterface +using Latexify, ArrayInterface using Setfield, ConstructionBase import Libdl using DocStringExtensions @@ -93,7 +93,7 @@ export independent_variables, unknowns, observables, parameters, full_parameters @reexport using UnPack RuntimeGeneratedFunctions.init(@__MODULE__) -import DynamicQuantities, Unitful +import DynamicQuantities const DQ = DynamicQuantities import DifferentiationInterface as DI @@ -232,15 +232,11 @@ include("deprecations.jl") const t_nounits = let only(@independent_variables t) end -const t_unitful = let - only(@independent_variables t [unit = Unitful.u"s"]) -end const t = let only(@independent_variables t [unit = DQ.u"s"]) end const D_nounits = Differential(t_nounits) -const D_unitful = Differential(t_unitful) const D = Differential(t) export ODEFunction, convert_system_indepvar, diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index c24c063ee0..fd35faecfb 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -890,17 +890,7 @@ function convert_units( DynamicQuantities.SymbolicUnits.as_quantity(varunits), value)) end -function convert_units(varunits::Unitful.FreeUnits, value) - Unitful.ustrip(varunits, value) -end - -convert_units(::Unitful.FreeUnits, value::NoValue) = NO_VALUE - -function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T} - Unitful.ustrip.(varunits, value) -end -convert_units(::Unitful.FreeUnits, value::Num) = value convert_units(::DynamicQuantities.Quantity, value::Num) = value @@ -919,8 +909,7 @@ function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types) try $setdefault($vv, $convert_units($unit, $name)) catch e - if isa(e, $(DynamicQuantities.DimensionError)) || - isa(e, $(Unitful.DimensionError)) + if $_is_dimension_error(e) error("Unable to convert units for \'" * string(:($$vv)) * "\'") elseif isa(e, MethodError) error("No or invalid units provided for \'" * string(:($$vv)) * diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index acf7451065..c9c1b77ed6 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -19,14 +19,13 @@ function __get_literal_unit(x) u = getmetadata(v, VariableUnit, nothing) u isa DQ.AbstractQuantity ? screen_unit(u) : u end +# Extensible unit type detection - extensions can add methods for specific unit types +_get_unittype(u) = nothing +_get_unittype(u::DQ.AbstractQuantity) = Val(:DynamicQuantities) + function __get_scalar_unit_type(v) u = __get_literal_unit(v) - if u isa DQ.AbstractQuantity - return Val(:DynamicQuantities) - elseif u isa Unitful.Unitlike - return Val(:Unitful) - end - return nothing + return _get_unittype(u) end function __get_unit_type(vs′...) for vs in vs′ @@ -44,24 +43,27 @@ function __get_unit_type(vs′...) return nothing end -function screen_unit(result) - if result isa DQ.AbstractQuantity - d = DQ.dimension(result) - if d isa DQ.Dimensions - return result - elseif d isa DQ.SymbolicDimensions - return DQ.uexpand(oneunit(result)) - else - throw(ValidationError("$result doesn't have a recognized unit")) - end + +function screen_unit(result::DQ.AbstractQuantity) + d = DQ.dimension(result) + if d isa DQ.Dimensions + return result + elseif d isa DQ.SymbolicDimensions + return DQ.uexpand(oneunit(result)) else - throw(ValidationError("$result doesn't have any unit.")) + throw(ValidationError("$result doesn't have a recognized unit")) end end +function screen_unit(result) + throw(ValidationError("$result doesn't have any unit.")) +end + const unitless = DQ.Quantity(1.0) get_literal_unit(x) = screen_unit(something(__get_literal_unit(x), unitless)) +_oneunit(x) = oneunit(x) + """ Find the unit of a symbolic item. """ @@ -69,7 +71,9 @@ get_unit(x::Real) = unitless get_unit(x::DQ.AbstractQuantity) = screen_unit(x) get_unit(x::AbstractArray) = map(get_unit, x) get_unit(x::Num) = get_unit(unwrap(x)) -get_unit(x::Symbolics.Arr) = get_unit(unwrap(x)) +function get_unit(x::Union{Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata}) + get_literal_unit(x) +end get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x) get_unit(op::Difference, args) = get_unit(args[1]) / get_unit(op.t) get_unit(op::typeof(getindex), args) = get_unit(args[1]) @@ -77,7 +81,7 @@ get_unit(x::SciMLBase.NullParameters) = unitless get_unit(op::typeof(instream), args) = get_unit(args[1]) function get_unit(op, args) # Fallback - result = oneunit(op(get_unit.(args)...)) + result = _oneunit(op(get_unit.(args)...)) try get_unit(result) catch @@ -104,7 +108,10 @@ function get_unit(op::Integral, args) return get_unit(args[1]) * unit end -equivalent(x, y) = isequal(x, y) +""" +Test unit equivalence. +""" +equivalent(x, y) = isequal(1 * x, 1 * y) function get_unit(op::Conditional, args) terms = get_unit.(args) terms[1] == unitless || @@ -167,6 +174,10 @@ function get_unit(x::Symbolic) end end +# Dimension error detection function - extensible for different unit systems +_is_dimension_error(e) = false # Default fallback +_is_dimension_error(e::DQ.DimensionError) = true + """ Get unit of term, returning nothing & showing warning instead of throwing errors. """ @@ -175,8 +186,8 @@ function safe_get_unit(term, info) try side = get_unit(term) catch err - if err isa DQ.DimensionError - @warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.") + if _is_dimension_error(err) + @warn("$info: dimension error occurred.") elseif err isa ValidationError @warn(info*err.message) elseif err isa MethodError diff --git a/src/systems/validation.jl b/src/systems/validation.jl index d416a02ea2..aacba03f49 100644 --- a/src/systems/validation.jl +++ b/src/systems/validation.jl @@ -1,287 +1,3 @@ -module UnitfulUnitCheck - -using ..ModelingToolkit, Symbolics, SciMLBase, Unitful, RecursiveArrayTools -using ..ModelingToolkit: ValidationError, - ModelingToolkit, Connection, instream, JumpType, VariableUnit, - get_systems, - Conditional, Comparison -using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump -using Symbolics: Symbolic, value, issym, isadd, ismul, ispow -const MT = ModelingToolkit - -Base.:*(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y -Base.:/(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x / y - -""" -Throw exception on invalid unit types, otherwise return argument. -""" -function screen_unit(result) - result isa Unitful.Unitlike || - throw(ValidationError("Unit must be a subtype of Unitful.Unitlike, not $(typeof(result)).")) - result isa Unitful.ScalarUnits || - throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead.")) - result == u"°" && - throw(ValidationError("Degrees are not supported. Use radians instead.")) - result -end - -""" -Test unit equivalence. - -Example of implemented behavior: - -```julia -using ModelingToolkit, Unitful -MT = ModelingToolkit -@parameters γ P [unit = u"MW"] E [unit = u"kJ"] τ [unit = u"ms"] -@test MT.equivalent(u"MW", u"kJ/ms") # Understands prefixes -@test !MT.equivalent(u"m", u"cm") # Units must be same magnitude -@test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E / τ)^γ)) # Handles symbolic exponents -``` -""" -equivalent(x, y) = isequal(1 * x, 1 * y) -const unitless = Unitful.unit(1) - -""" -Find the unit of a symbolic item. -""" -get_unit(x::Real) = unitless -get_unit(x::Unitful.Quantity) = screen_unit(Unitful.unit(x)) -get_unit(x::AbstractArray) = map(get_unit, x) -get_unit(x::Num) = get_unit(value(x)) -function get_unit(x::Union{Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata}) - get_literal_unit(x) -end -get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x) -get_unit(op::typeof(getindex), args) = get_unit(args[1]) -get_unit(x::SciMLBase.NullParameters) = unitless -get_unit(op::typeof(instream), args) = get_unit(args[1]) - -get_literal_unit(x) = screen_unit(getmetadata(x, VariableUnit, unitless)) - -function get_unit(op, args) # Fallback - result = op(1 .* get_unit.(args)...) - try - unit(result) - catch - throw(ValidationError("Unable to get unit for operation $op with arguments $args.")) - end -end - -function get_unit(op::Integral, args) - unit = 1 - if op.domain.variables isa Vector - for u in op.domain.variables - unit *= get_unit(u) - end - else - unit *= get_unit(op.domain.variables) - end - return get_unit(args[1]) * unit -end - -function get_unit(op::Conditional, args) - terms = get_unit.(args) - terms[1] == unitless || - throw(ValidationError(", in $op, [$(terms[1])] is not dimensionless.")) - equivalent(terms[2], terms[3]) || - throw(ValidationError(", in $op, units [$(terms[2])] and [$(terms[3])] do not match.")) - return terms[2] -end - -function get_unit(op::typeof(Symbolics._mapreduce), args) - if args[2] == + - get_unit(args[3]) - else - throw(ValidationError("Unsupported array operation $op")) - end -end - -function get_unit(op::Comparison, args) - terms = get_unit.(args) - equivalent(terms[1], terms[2]) || - throw(ValidationError(", in comparison $op, units [$(terms[1])] and [$(terms[2])] do not match.")) - return unitless -end - -function get_unit(x::Symbolic) - if issym(x) - get_literal_unit(x) - elseif isadd(x) - terms = get_unit.(arguments(x)) - firstunit = terms[1] - for other in terms[2:end] - termlist = join(map(repr, terms), ", ") - equivalent(other, firstunit) || - throw(ValidationError(", in sum $x, units [$termlist] do not match.")) - end - return firstunit - elseif ispow(x) - pargs = arguments(x) - base, expon = get_unit.(pargs) - @assert expon isa Unitful.DimensionlessUnits - if base == unitless - unitless - else - pargs[2] isa Number ? base^pargs[2] : (1 * base)^pargs[2] - end - elseif iscall(x) - op = operation(x) - if issym(op) || (iscall(op) && iscall(operation(op))) # Dependent variables, not function calls - return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i] - elseif iscall(op) && !iscall(operation(op)) - gp = getmetadata(x, Symbolics.GetindexParent, nothing) # Like x[1](t) - return screen_unit(getmetadata(gp, VariableUnit, unitless)) - end # Actual function calls: - args = arguments(x) - return get_unit(op, args) - else # This function should only be reached by Terms, for which `iscall` is true - throw(ArgumentError("Unsupported value $x.")) - end -end - -""" -Get unit of term, returning nothing & showing warning instead of throwing errors. -""" -function safe_get_unit(term, info) - side = nothing - try - side = get_unit(term) - catch err - if err isa Unitful.DimensionError - @warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.") - elseif err isa ValidationError - @warn(info*err.message) - elseif err isa MethodError - @warn("$info: no method matching $(err.f) for arguments $(typeof.(err.args)).") - else - rethrow() - end - end - side -end - -function _validate(terms::Vector, labels::Vector{String}; info::String = "") - valid = true - first_unit = nothing - first_label = nothing - for (term, label) in zip(terms, labels) - equnit = safe_get_unit(term, info * label) - if equnit === nothing - valid = false - elseif !isequal(term, 0) - if first_unit === nothing - first_unit = equnit - first_label = label - elseif !equivalent(first_unit, equnit) - valid = false - @warn("$info: units [$(first_unit)] for $(first_label) and [$(equnit)] for $(label) do not match.") - end - end - end - valid -end - -function _validate(conn::Connection; info::String = "") - valid = true - syss = get_systems(conn) - sys = first(syss) - unks = unknowns(sys) - for i in 2:length(syss) - s = syss[i] - _unks = unknowns(s) - if length(unks) != length(_unks) - valid = false - @warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) have $(length(unks)) and $(length(_unks)) unknowns, cannot connect.") - continue - end - for (i, x) in enumerate(unks) - j = findfirst(isequal(x), _unks) - if j == nothing - valid = false - @warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) do not have the same unknowns.") - else - aunit = safe_get_unit(x, info * string(nameof(sys)) * "#$i") - bunit = safe_get_unit(_unks[j], info * string(nameof(s)) * "#$j") - if !equivalent(aunit, bunit) - valid = false - @warn("$info: connected system unknowns $x and $(_unks[j]) have mismatched units.") - end - end - end - end - valid -end - -function validate(jump::Union{MT.VariableRateJump, - MT.ConstantRateJump}, t::Symbolic; - info::String = "") - newinfo = replace(info, "eq." => "jump") - _validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units - validate(jump.affect!, info = newinfo) -end - -function validate(jump::MT.MassActionJump, t::Symbolic; info::String = "") - left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols - net_symbols = [x[1] for x in jump.net_stoch] - all_symbols = vcat(left_symbols, net_symbols) - allgood = _validate(all_symbols, string.(all_symbols); info) - n = sum(x -> x[2], jump.reactant_stoch, init = 0) - base_unitful = all_symbols[1] #all same, get first - allgood && _validate([jump.scaled_rates, 1 / (t * base_unitful^n)], - ["scaled_rates", "1/(t*reactants^$n))"]; info) -end - -function validate(jumps::Vector{JumpType}, t::Symbolic) - labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"] - majs = filter(x -> x isa MassActionJump, jumps) - crjs = filter(x -> x isa ConstantRateJump, jumps) - vrjs = filter(x -> x isa VariableRateJump, jumps) - splitjumps = [majs, crjs, vrjs] - all([validate(js, t; info) for (js, info) in zip(splitjumps, labels)]) -end - -function validate(eq::MT.Equation; info::String = "") - if typeof(eq.lhs) == Connection - _validate(eq.rhs; info) - else - _validate([eq.lhs, eq.rhs], ["left", "right"]; info) - end -end -function validate(eq::MT.Equation, - term::Union{Symbolic, Unitful.Quantity, Num}; info::String = "") - _validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info) -end -function validate(eq::MT.Equation, terms::Vector; info::String = "") - _validate(vcat([eq.lhs, eq.rhs], terms), - vcat(["left", "right"], "noise #" .* string.(1:length(terms))); info) -end - -""" -Returns true iff units of equations are valid. -""" -function validate(eqs::Vector; info::String = "") - all([validate(eqs[idx], info = info * " in eq. #$idx") for idx in 1:length(eqs)]) -end -function validate(eqs::Vector, noise::Vector; info::String = "") - all([validate(eqs[idx], noise[idx], info = info * " in eq. #$idx") - for idx in 1:length(eqs)]) -end -function validate(eqs::Vector, noise::Matrix; info::String = "") - all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx") - for idx in 1:length(eqs)]) -end -function validate(eqs::Vector, term::Symbolic; info::String = "") - all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)]) -end -validate(term::Symbolics.SymbolicUtils.Symbolic) = safe_get_unit(term, "") !== nothing - -""" -Throws error if units of equations are invalid. -""" -function MT.check_units(::Val{:Unitful}, eqs...) - validate(eqs...) || - throw(ValidationError("Some equations had invalid units. See warnings for details.")) -end - -end # module +# This file is kept for backward compatibility +# The UnitfulUnitCheck module is now provided by the ModelingToolkitUnitfulExt extension +# when Unitful is loaded \ No newline at end of file diff --git a/test/constants.jl b/test/constants.jl index ce5c7e6e8e..0d6376299f 100644 --- a/test/constants.jl +++ b/test/constants.jl @@ -1,7 +1,6 @@ using ModelingToolkit, OrdinaryDiffEq, Unitful using Test MT = ModelingToolkit -UMT = ModelingToolkit.UnitfulUnitCheck @constants a = 1 @test isconstant(a) @@ -25,7 +24,7 @@ simp = mtkcompile(sys) #Constant with units @constants β=1 [unit = u"m/s"] -UMT.get_unit(β) +ModelingToolkit.get_unit(β) @test MT.isconstant(β) @test !MT.istunable(β) @independent_variables t [unit = u"s"] diff --git a/test/runtests.jl b/test/runtests.jl index 47230c9539..8d37754cf7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,7 +29,6 @@ end @safetestset "AbstractSystem Test" include("abstractsystem.jl") @safetestset "Variable Scope Tests" include("variable_scope.jl") @safetestset "Symbolic Parameters Test" include("symbolic_parameters.jl") - @safetestset "Parsing Test" include("variable_parsing.jl") @safetestset "Simplify Test" include("simplify.jl") @safetestset "Direct Usage Test" include("direct.jl") @safetestset "System Linearity Test" include("linearity.jl") @@ -37,13 +36,11 @@ end @safetestset "Clock Test" include("clock.jl") @safetestset "ODESystem Test" include("odesystem.jl") @safetestset "Dynamic Quantities Test" include("dq_units.jl") - @safetestset "Unitful Quantities Test" include("units.jl") @safetestset "Mass Matrix Test" include("mass_matrix.jl") @safetestset "Reduction Test" include("reduction.jl") @safetestset "Split Parameters Test" include("split_parameters.jl") @safetestset "StaticArrays Test" include("static_arrays.jl") @safetestset "Components Test" include("components.jl") - @safetestset "Model Parsing Test" include("model_parsing.jl") @safetestset "Error Handling" include("error_handling.jl") @safetestset "StructuralTransformations" include("structural_transformation/runtests.jl") @safetestset "Basic transformations" include("basic_transformations.jl") @@ -58,7 +55,6 @@ end @safetestset "DAE Jacobians Test" include("dae_jacobian.jl") @safetestset "Jacobian Sparsity" include("jacobiansparsity.jl") @safetestset "Modelingtoolkitize Test" include("modelingtoolkitize.jl") - @safetestset "Constants Test" include("constants.jl") @safetestset "Parameter Dependency Test" include("parameter_dependencies.jl") @safetestset "Equation Type Accessors Test" include("equation_type_accessors.jl") @safetestset "System Accessor Functions Test" include("accessor_functions.jl") @@ -141,5 +137,9 @@ end @safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl") @safetestset "InfiniteOpt Extension Test" include("extensions/test_infiniteopt.jl") @safetestset "Auto Differentiation Test" include("extensions/ad.jl") + @safetestset "Unitful Extension Test" include("units.jl") + @safetestset "Variable Parsing with Units Test" include("variable_parsing.jl") + @safetestset "Model Parsing with Units Test" include("model_parsing.jl") + @safetestset "Constants with Units Test" include("constants.jl") end end diff --git a/test/units.jl b/test/units.jl index a17dd90575..cd3264a9f8 100644 --- a/test/units.jl +++ b/test/units.jl @@ -1,53 +1,56 @@ using ModelingToolkit, OrdinaryDiffEq, JumpProcesses, Unitful using Test MT = ModelingToolkit -UMT = ModelingToolkit.UnitfulUnitCheck + +# All unit functions are now directly available from ModelingToolkit +# Extension automatically loads when Unitful is imported +unitless = 1 @independent_variables t [unit = u"ms"] @parameters τ [unit = u"ms"] γ @variables E(t) [unit = u"kJ"] P(t) [unit = u"MW"] D = Differential(t) #This is how equivalent works: -@test UMT.equivalent(u"MW", u"kJ/ms") -@test !UMT.equivalent(u"m", u"cm") -@test UMT.equivalent(UMT.get_unit(P^γ), UMT.get_unit((E / τ)^γ)) +@test ModelingToolkit.equivalent(u"MW", u"kJ/ms") +@test !ModelingToolkit.equivalent(u"m", u"cm") +@test ModelingToolkit.equivalent(ModelingToolkit.get_unit(P^γ), ModelingToolkit.get_unit((E / τ)^γ)) # Basic access -@test UMT.get_unit(t) == u"ms" -@test UMT.get_unit(E) == u"kJ" -@test UMT.get_unit(τ) == u"ms" -@test UMT.get_unit(γ) == UMT.unitless -@test UMT.get_unit(0.5) == UMT.unitless -@test UMT.get_unit(UMT.SciMLBase.NullParameters()) == UMT.unitless +@test ModelingToolkit.get_unit(t) == u"ms" +@test ModelingToolkit.get_unit(E) == u"kJ" +@test ModelingToolkit.get_unit(τ) == u"ms" +@test ModelingToolkit.get_unit(γ) == unitless +@test ModelingToolkit.get_unit(0.5) == unitless +@test ModelingToolkit.get_unit(ModelingToolkit.SciMLBase.NullParameters()) == unitless # Prohibited unit types @parameters β [unit = u"°"] α [unit = u"°C"] γ [unit = 1u"s"] -@test_throws UMT.ValidationError UMT.get_unit(β) -@test_throws UMT.ValidationError UMT.get_unit(α) -@test_throws UMT.ValidationError UMT.get_unit(γ) +@test_throws ModelingToolkit.ValidationError ModelingToolkit.get_unit(β) +@test_throws ModelingToolkit.ValidationError ModelingToolkit.get_unit(α) +@test_throws ModelingToolkit.ValidationError ModelingToolkit.get_unit(γ) # Non-trivial equivalence & operators -@test UMT.get_unit(τ^-1) == u"ms^-1" -@test UMT.equivalent(UMT.get_unit(D(E)), u"MW") -@test UMT.equivalent(UMT.get_unit(E / τ), u"MW") -@test UMT.get_unit(2 * P) == u"MW" -@test UMT.get_unit(t / τ) == UMT.unitless -@test UMT.equivalent(UMT.get_unit(P - E / τ), u"MW") -@test UMT.equivalent(UMT.get_unit(D(D(E))), u"MW/ms") -@test UMT.get_unit(ifelse(t > t, P, E / τ)) == u"MW" -@test UMT.get_unit(1.0^(t / τ)) == UMT.unitless -@test UMT.get_unit(exp(t / τ)) == UMT.unitless -@test UMT.get_unit(sin(t / τ)) == UMT.unitless -@test UMT.get_unit(sin(1 * u"rad")) == UMT.unitless -@test UMT.get_unit(t^2) == u"ms^2" +@test ModelingToolkit.get_unit(τ^-1) == u"ms^-1" +@test ModelingToolkit.equivalent(ModelingToolkit.get_unit(D(E)), u"MW") +@test ModelingToolkit.equivalent(ModelingToolkit.get_unit(E / τ), u"MW") +@test ModelingToolkit.get_unit(2 * P) == u"MW" +@test ModelingToolkit.get_unit(t / τ) == unitless +@test ModelingToolkit.equivalent(ModelingToolkit.get_unit(P - E / τ), u"MW") +@test ModelingToolkit.equivalent(ModelingToolkit.get_unit(D(D(E))), u"MW/ms") +@test ModelingToolkit.get_unit(ifelse(t > t, P, E / τ)) == u"MW" +@test ModelingToolkit.get_unit(1.0^(t / τ)) == unitless +@test ModelingToolkit.get_unit(exp(t / τ)) == unitless +@test ModelingToolkit.get_unit(sin(t / τ)) == unitless +@test ModelingToolkit.get_unit(sin(1 * u"rad")) == unitless +@test ModelingToolkit.get_unit(t^2) == u"ms^2" eqs = [D(E) ~ P - E / τ 0 ~ P] -@test UMT.validate(eqs) +@test ModelingToolkit.validate(eqs) @named sys = System(eqs, t) -@test !UMT.validate(D(D(E)) ~ P) -@test !UMT.validate(0 ~ P + E * τ) +@test !ModelingToolkit.validate(D(D(E)) ~ P) +@test !ModelingToolkit.validate(0 ~ P + E * τ) # Disabling unit validation/checks selectively @test_throws MT.ArgumentError System(eqs, t, [E, P, t], [τ], name = :sys) @@ -88,9 +91,9 @@ end good_eqs = [connect(p1, p2)] bad_eqs = [connect(p1, p2, op)] bad_length_eqs = [connect(op, lp)] -@test UMT.validate(good_eqs) -@test !UMT.validate(bad_eqs) -@test !UMT.validate(bad_length_eqs) +@test ModelingToolkit.validate(good_eqs) +@test !ModelingToolkit.validate(bad_eqs) +@test !ModelingToolkit.validate(bad_length_eqs) @named sys = System(good_eqs, t, [], []) @test_throws MT.ValidationError System(bad_eqs, t, [], []; name = :sys) @@ -130,7 +133,7 @@ noiseeqs = [0.1u"MW" 0.1u"MW" # Invalid noise matrix noiseeqs = [0.1u"MW" 0.1u"MW" 0.1u"MW" 0.1u"s"] -@test !UMT.validate(eqs, noiseeqs) +@test !ModelingToolkit.validate(eqs, noiseeqs) # Non-trivial simplifications @independent_variables t [unit = u"s"] @@ -222,7 +225,7 @@ end @test ModelingToolkit.getdefault(sys.a) ≈ [0.01, 0.03] @variables x(t) -@test ModelingToolkit.get_unit(sin(x)) == ModelingToolkit.unitless +@test ModelingToolkit.get_unit(sin(x)) == unitless @mtkmodel ExpressionParametersTest begin @parameters begin