Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions src/Expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,24 @@ import ..SimplifyModule: combine_operators, simplify_tree!
struct Metadata{NT<:NamedTuple}
_data::NT
end
_data(x::Metadata) = getfield(x, :_data)
unpack_metadata(x::Metadata) = getfield(x, :_data)

Base.propertynames(x::Metadata) = propertynames(_data(x))
@unstable @inline Base.getproperty(x::Metadata, f::Symbol) = getproperty(_data(x), f)
Base.show(io::IO, x::Metadata) = print(io, "Metadata(", _data(x), ")")
Base.propertynames(x::Metadata) = propertynames(unpack_metadata(x))
@unstable @inline function Base.getproperty(x::Metadata, f::Symbol)
return getproperty(unpack_metadata(x), f)
end
Base.show(io::IO, x::Metadata) = print(io, "Metadata(", unpack_metadata(x), ")")
@inline _copy(x) = copy(x)
@inline _copy(x::NamedTuple) = copy_named_tuple(x)
@inline _copy(x::Nothing) = nothing
@inline _copy(::Nothing) = nothing
@inline function copy_named_tuple(nt::NamedTuple)
return NamedTuple{keys(nt)}(map(_copy, values(nt)))
end
@inline function Base.copy(metadata::Metadata)
return Metadata(_copy(_data(metadata)))
return Metadata(_copy(unpack_metadata(metadata)))
end
@inline Base.:(==)(x::Metadata, y::Metadata) = _data(x) == _data(y)
@inline Base.hash(x::Metadata, h::UInt) = hash(_data(x), h)
@inline Base.:(==)(x::Metadata, y::Metadata) = unpack_metadata(x) == unpack_metadata(y)
@inline Base.hash(x::Metadata, h::UInt) = hash(unpack_metadata(x), h)

"""
AbstractExpression{T,N}
Expand Down Expand Up @@ -216,7 +218,9 @@ end
Create a new expression based on `ex` but with a different `metadata`.
"""
function with_metadata(ex::AbstractExpression; metadata...)
return with_metadata(ex, Metadata((; metadata...)))
return with_metadata(
ex, Metadata((; unpack_metadata(get_metadata(ex))..., metadata...))
)
end
function with_metadata(ex::AbstractExpression, metadata::Metadata)
return constructorof(typeof(ex))(get_contents(ex), metadata)
Expand Down Expand Up @@ -246,7 +250,13 @@ end
function get_variable_names(
ex::Expression, variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing
)
return variable_names === nothing ? ex.metadata.variable_names : variable_names
return if variable_names !== nothing
variable_names
elseif hasproperty(ex.metadata, :variable_names)
ex.metadata.variable_names
else
nothing
end
end
function get_tree(ex::Expression)
return ex.tree
Expand Down
30 changes: 15 additions & 15 deletions src/ParametricExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk

using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
using ..ExpressionModule: AbstractExpression, Metadata, with_contents, with_metadata
using ..ExpressionModule:
AbstractExpression, Metadata, with_contents, with_metadata, unpack_metadata
using ..ChainRulesModule: NodeTangent

import ..NodeModule:
Expand Down Expand Up @@ -63,7 +64,6 @@ mutable struct ParametricNode{T} <: AbstractExpressionNode{T}
return n
end
end
@inline _data(x::Metadata) = getfield(x, :_data)

"""
ParametricExpression{T,N<:ParametricNode{T},D<:NamedTuple} <: AbstractExpression{T,N}
Expand All @@ -79,15 +79,17 @@ struct ParametricExpression{
metadata::Metadata{D}

function ParametricExpression(tree::ParametricNode, metadata::Metadata)
return new{eltype(tree),typeof(tree),typeof(_data(metadata))}(tree, metadata)
return new{eltype(tree),typeof(tree),typeof(unpack_metadata(metadata))}(
tree, metadata
)
end
end
function ParametricExpression(
tree::ParametricNode{T1};
operators::Union{AbstractOperatorEnum,Nothing},
variable_names,
variable_names=nothing,
parameters::AbstractMatrix{T2},
parameter_names,
parameter_names=nothing,
) where {T1,T2}
if !isnothing(parameter_names)
@assert size(parameters, 1) == length(parameter_names)
Expand Down Expand Up @@ -200,18 +202,16 @@ function get_variable_names(
ex::ParametricExpression,
variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing,
)
return variable_names === nothing ? ex.metadata.variable_names : variable_names
return if variable_names !== nothing
variable_names
elseif hasproperty(ex.metadata, :variable_names)
ex.metadata.variable_names
else
nothing
end
end
@inline _copy_with_nothing(x) = copy(x)
@inline _copy_with_nothing(::Nothing) = nothing
function Base.copy(ex::ParametricExpression; break_sharing::Val=Val(false))
return ParametricExpression(
copy(ex.tree; break_sharing=break_sharing);
operators=_copy_with_nothing(ex.metadata.operators),
variable_names=_copy_with_nothing(ex.metadata.variable_names),
parameters=_copy_with_nothing(ex.metadata.parameters),
parameter_names=_copy_with_nothing(ex.metadata.parameter_names),
)
return ParametricExpression(copy(ex.tree; break_sharing), copy(ex.metadata))
end
###############################################################################

Expand Down
4 changes: 2 additions & 2 deletions src/Strings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ function string_tree(
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
f_variable::F1=string_variable,
f_constant::F2=string_constant,
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
variable_names=nothing,
pretty::Union{Bool,Nothing}=nothing, # Not used, but can be used by other types
# Deprecated
raw::Union{Bool,Nothing}=nothing,
Expand Down Expand Up @@ -190,7 +190,7 @@ for io in ((), (:(io::IO),))
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
f_variable::F1=string_variable,
f_constant::F2=string_constant,
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
variable_names=nothing,
pretty::Union{Bool,Nothing}=nothing, # Not used, but can be used by other types
# Deprecated
raw::Union{Bool,Nothing}=nothing,
Expand Down
12 changes: 9 additions & 3 deletions src/StructuredExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import ..ExpressionModule:
with_contents,
Metadata,
_copy,
_data,
unpack_metadata,
default_node_type,
node_type,
get_scalar_constants,
Expand Down Expand Up @@ -114,7 +114,7 @@ constructorof(::Type{<:StructuredExpression}) = StructuredExpression
function Base.copy(e::AbstractStructuredExpression)
ts = get_contents(e)
meta = get_metadata(e)
meta_inner = _data(meta)
meta_inner = unpack_metadata(meta)
copy_ts = NamedTuple{keys(ts)}(map(copy, values(ts)))
keys_except_structure = filter(!=(:structure), keys(meta_inner))
copy_metadata = (;
Expand Down Expand Up @@ -143,7 +143,13 @@ function get_variable_names(
e::AbstractStructuredExpression,
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
)
return variable_names === nothing ? get_metadata(e).variable_names : variable_names
return if variable_names !== nothing
variable_names
elseif hasproperty(get_metadata(e), :variable_names)
get_metadata(e).variable_names
else
nothing
end
end
function get_scalar_constants(e::AbstractStructuredExpression)
# Get constants for each inner expression
Expand Down
43 changes: 42 additions & 1 deletion test/test_expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ end

@testitem "Miscellaneous expression calls" begin
using DynamicExpressions
using DynamicExpressions: get_tree, get_operators
using DynamicExpressions: get_tree, get_operators, default_node_type

ex = @parse_expression(x1 + 1.5, binary_operators = [+], variable_names = ["x1"])
@test DynamicExpressions.ExpressionModule.node_type(ex) <: Node
Expand All @@ -278,6 +278,24 @@ end

tree = get_tree(ex)
@test_throws ArgumentError get_operators(tree, nothing)

# We can also define expressions without variable names, and it should work
operators = OperatorEnum(; binary_operators=[+])
for E in (Expression, ParametricExpression)
N = default_node_type(E)
kws = (; operators)
if E === ParametricExpression
kws = (; kws..., parameters=Matrix{Float64}(undef, 0, 0))
end
x1, x2 = (E(N(Float64; feature=i); kws...) for i in 1:2)
x1000 = E(N(Float64; feature=1000); kws...)
@test string(x1 + x2 + x1000) == "(x1 + x2) + x1000"
# And also with structured expressions
x1 = StructuredExpression(
(; x1, x2, x1000); operators, structure=nt -> nt.x1 + nt.x2 + nt.x1000
)
@test string(x1) == "(x1 + x2) + x1000"
end
end

@testitem "Expression Literate examples" begin
Expand Down Expand Up @@ -413,3 +431,26 @@ end

#literate_end
end

@testitem "Expression with_metadata partial updates" begin
using DynamicExpressions
using DynamicExpressions: get_operators, get_metadata, with_metadata, get_variable_names

# Create an expression with initial metadata
ex = @parse_expression(
x1 + 1.5,
operators = OperatorEnum(; binary_operators=[+, *]),
variable_names = ["x1"]
)

# Update only the variable_names, keeping the original operators
new_ex = with_metadata(ex; variable_names=["y1"])
@test get_variable_names(new_ex, nothing) == ["y1"]
@test get_operators(new_ex, nothing) == get_operators(ex, nothing)

# Update only the operators, keeping the original variable_names
new_operators = OperatorEnum(; binary_operators=[+])
new_ex2 = with_metadata(ex; operators=new_operators)
@test get_variable_names(new_ex2, nothing) == ["x1"]
@test get_operators(new_ex2, nothing) == new_operators
end
Loading