diff --git a/src/Expression.jl b/src/Expression.jl index 68fd8818..9e7325a6 100644 --- a/src/Expression.jl +++ b/src/Expression.jl @@ -31,22 +31,24 @@ import ..SimplifyModule: combine_operators, simplify_tree! struct Metadata{NT<:NamedTuple} _data::NT end -_data(x::Metadata) = getfield(x, :_data) +unpack_metadata(x::Metadata) = getfield(x, :_data) -Base.propertynames(x::Metadata) = propertynames(_data(x)) -@unstable @inline Base.getproperty(x::Metadata, f::Symbol) = getproperty(_data(x), f) -Base.show(io::IO, x::Metadata) = print(io, "Metadata(", _data(x), ")") +Base.propertynames(x::Metadata) = propertynames(unpack_metadata(x)) +@unstable @inline function Base.getproperty(x::Metadata, f::Symbol) + return getproperty(unpack_metadata(x), f) +end +Base.show(io::IO, x::Metadata) = print(io, "Metadata(", unpack_metadata(x), ")") @inline _copy(x) = copy(x) @inline _copy(x::NamedTuple) = copy_named_tuple(x) -@inline _copy(x::Nothing) = nothing +@inline _copy(::Nothing) = nothing @inline function copy_named_tuple(nt::NamedTuple) return NamedTuple{keys(nt)}(map(_copy, values(nt))) end @inline function Base.copy(metadata::Metadata) - return Metadata(_copy(_data(metadata))) + return Metadata(_copy(unpack_metadata(metadata))) end -@inline Base.:(==)(x::Metadata, y::Metadata) = _data(x) == _data(y) -@inline Base.hash(x::Metadata, h::UInt) = hash(_data(x), h) +@inline Base.:(==)(x::Metadata, y::Metadata) = unpack_metadata(x) == unpack_metadata(y) +@inline Base.hash(x::Metadata, h::UInt) = hash(unpack_metadata(x), h) """ AbstractExpression{T,N} @@ -216,7 +218,9 @@ end Create a new expression based on `ex` but with a different `metadata`. """ function with_metadata(ex::AbstractExpression; metadata...) - return with_metadata(ex, Metadata((; metadata...))) + return with_metadata( + ex, Metadata((; unpack_metadata(get_metadata(ex))..., metadata...)) + ) end function with_metadata(ex::AbstractExpression, metadata::Metadata) return constructorof(typeof(ex))(get_contents(ex), metadata) @@ -246,7 +250,13 @@ end function get_variable_names( ex::Expression, variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing ) - return variable_names === nothing ? ex.metadata.variable_names : variable_names + return if variable_names !== nothing + variable_names + elseif hasproperty(ex.metadata, :variable_names) + ex.metadata.variable_names + else + nothing + end end function get_tree(ex::Expression) return ex.tree diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 16d27254..854e28d7 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -5,7 +5,8 @@ using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce -using ..ExpressionModule: AbstractExpression, Metadata, with_contents, with_metadata +using ..ExpressionModule: + AbstractExpression, Metadata, with_contents, with_metadata, unpack_metadata using ..ChainRulesModule: NodeTangent import ..NodeModule: @@ -63,7 +64,6 @@ mutable struct ParametricNode{T} <: AbstractExpressionNode{T} return n end end -@inline _data(x::Metadata) = getfield(x, :_data) """ ParametricExpression{T,N<:ParametricNode{T},D<:NamedTuple} <: AbstractExpression{T,N} @@ -79,15 +79,17 @@ struct ParametricExpression{ metadata::Metadata{D} function ParametricExpression(tree::ParametricNode, metadata::Metadata) - return new{eltype(tree),typeof(tree),typeof(_data(metadata))}(tree, metadata) + return new{eltype(tree),typeof(tree),typeof(unpack_metadata(metadata))}( + tree, metadata + ) end end function ParametricExpression( tree::ParametricNode{T1}; operators::Union{AbstractOperatorEnum,Nothing}, - variable_names, + variable_names=nothing, parameters::AbstractMatrix{T2}, - parameter_names, + parameter_names=nothing, ) where {T1,T2} if !isnothing(parameter_names) @assert size(parameters, 1) == length(parameter_names) @@ -200,18 +202,16 @@ function get_variable_names( ex::ParametricExpression, variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, ) - return variable_names === nothing ? ex.metadata.variable_names : variable_names + return if variable_names !== nothing + variable_names + elseif hasproperty(ex.metadata, :variable_names) + ex.metadata.variable_names + else + nothing + end end -@inline _copy_with_nothing(x) = copy(x) -@inline _copy_with_nothing(::Nothing) = nothing function Base.copy(ex::ParametricExpression; break_sharing::Val=Val(false)) - return ParametricExpression( - copy(ex.tree; break_sharing=break_sharing); - operators=_copy_with_nothing(ex.metadata.operators), - variable_names=_copy_with_nothing(ex.metadata.variable_names), - parameters=_copy_with_nothing(ex.metadata.parameters), - parameter_names=_copy_with_nothing(ex.metadata.parameter_names), - ) + return ParametricExpression(copy(ex.tree; break_sharing), copy(ex.metadata)) end ############################################################################### diff --git a/src/Strings.jl b/src/Strings.jl index c4a3f0dd..f4dfa204 100644 --- a/src/Strings.jl +++ b/src/Strings.jl @@ -139,7 +139,7 @@ function string_tree( operators::Union{AbstractOperatorEnum,Nothing}=nothing; f_variable::F1=string_variable, f_constant::F2=string_constant, - variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, + variable_names=nothing, pretty::Union{Bool,Nothing}=nothing, # Not used, but can be used by other types # Deprecated raw::Union{Bool,Nothing}=nothing, @@ -190,7 +190,7 @@ for io in ((), (:(io::IO),)) operators::Union{AbstractOperatorEnum,Nothing}=nothing; f_variable::F1=string_variable, f_constant::F2=string_constant, - variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, + variable_names=nothing, pretty::Union{Bool,Nothing}=nothing, # Not used, but can be used by other types # Deprecated raw::Union{Bool,Nothing}=nothing, diff --git a/src/StructuredExpression.jl b/src/StructuredExpression.jl index 963da0e4..8b1550c3 100644 --- a/src/StructuredExpression.jl +++ b/src/StructuredExpression.jl @@ -16,7 +16,7 @@ import ..ExpressionModule: with_contents, Metadata, _copy, - _data, + unpack_metadata, default_node_type, node_type, get_scalar_constants, @@ -114,7 +114,7 @@ constructorof(::Type{<:StructuredExpression}) = StructuredExpression function Base.copy(e::AbstractStructuredExpression) ts = get_contents(e) meta = get_metadata(e) - meta_inner = _data(meta) + meta_inner = unpack_metadata(meta) copy_ts = NamedTuple{keys(ts)}(map(copy, values(ts))) keys_except_structure = filter(!=(:structure), keys(meta_inner)) copy_metadata = (; @@ -143,7 +143,13 @@ function get_variable_names( e::AbstractStructuredExpression, variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, ) - return variable_names === nothing ? get_metadata(e).variable_names : variable_names + return if variable_names !== nothing + variable_names + elseif hasproperty(get_metadata(e), :variable_names) + get_metadata(e).variable_names + else + nothing + end end function get_scalar_constants(e::AbstractStructuredExpression) # Get constants for each inner expression diff --git a/test/test_expressions.jl b/test/test_expressions.jl index 63f9a9a2..54448001 100644 --- a/test/test_expressions.jl +++ b/test/test_expressions.jl @@ -269,7 +269,7 @@ end @testitem "Miscellaneous expression calls" begin using DynamicExpressions - using DynamicExpressions: get_tree, get_operators + using DynamicExpressions: get_tree, get_operators, default_node_type ex = @parse_expression(x1 + 1.5, binary_operators = [+], variable_names = ["x1"]) @test DynamicExpressions.ExpressionModule.node_type(ex) <: Node @@ -278,6 +278,24 @@ end tree = get_tree(ex) @test_throws ArgumentError get_operators(tree, nothing) + + # We can also define expressions without variable names, and it should work + operators = OperatorEnum(; binary_operators=[+]) + for E in (Expression, ParametricExpression) + N = default_node_type(E) + kws = (; operators) + if E === ParametricExpression + kws = (; kws..., parameters=Matrix{Float64}(undef, 0, 0)) + end + x1, x2 = (E(N(Float64; feature=i); kws...) for i in 1:2) + x1000 = E(N(Float64; feature=1000); kws...) + @test string(x1 + x2 + x1000) == "(x1 + x2) + x1000" + # And also with structured expressions + x1 = StructuredExpression( + (; x1, x2, x1000); operators, structure=nt -> nt.x1 + nt.x2 + nt.x1000 + ) + @test string(x1) == "(x1 + x2) + x1000" + end end @testitem "Expression Literate examples" begin @@ -413,3 +431,26 @@ end #literate_end end + +@testitem "Expression with_metadata partial updates" begin + using DynamicExpressions + using DynamicExpressions: get_operators, get_metadata, with_metadata, get_variable_names + + # Create an expression with initial metadata + ex = @parse_expression( + x1 + 1.5, + operators = OperatorEnum(; binary_operators=[+, *]), + variable_names = ["x1"] + ) + + # Update only the variable_names, keeping the original operators + new_ex = with_metadata(ex; variable_names=["y1"]) + @test get_variable_names(new_ex, nothing) == ["y1"] + @test get_operators(new_ex, nothing) == get_operators(ex, nothing) + + # Update only the operators, keeping the original variable_names + new_operators = OperatorEnum(; binary_operators=[+]) + new_ex2 = with_metadata(ex; operators=new_operators) + @test get_variable_names(new_ex2, nothing) == ["x1"] + @test get_operators(new_ex2, nothing) == new_operators +end