diff --git a/Project.toml b/Project.toml index 62fdb620..b9f7e630 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DynamicExpressions" uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b" authors = ["MilesCranmer "] -version = "1.1.0" +version = "1.2.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 385b5079..7e6079f9 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -93,6 +93,7 @@ import .ExpressionModule: import .ParseModule: parse_leaf @reexport import .ParametricExpressionModule: ParametricExpression, ParametricNode @reexport import .StructuredExpressionModule: StructuredExpression +import .StructuredExpressionModule: AbstractStructuredExpression @stable default_mode = "disable" begin include("Interfaces.jl") diff --git a/src/StructuredExpression.jl b/src/StructuredExpression.jl index f2e68cf7..0282a63f 100644 --- a/src/StructuredExpression.jl +++ b/src/StructuredExpression.jl @@ -5,6 +5,7 @@ using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce using ..ExpressionModule: AbstractExpression, Metadata, node_type using ..ChainRulesModule: NodeTangent +import ..NodeModule: constructorof import ..ExpressionModule: get_contents, get_metadata, @@ -13,17 +14,33 @@ import ..ExpressionModule: get_variable_names, Metadata, _copy, + _data, default_node_type, node_type, get_scalar_constants, set_scalar_constants! +abstract type AbstractStructuredExpression{ + T,F<:Function,N<:AbstractExpressionNode{T},E<:AbstractExpression{T,N},D<:NamedTuple +} <: AbstractExpression{T,N} end + """ - StructuredExpression + StructuredExpression{T,F,N,E,TS,D} <: AbstractStructuredExpression{T,F,N,E,D} <: AbstractExpression{T,N} This expression type allows you to combine multiple expressions together in a predefined way. +# Parameters + +- `T`: The numeric value type of the expressions. +- `F`: The type of the structure function, which combines each expression into a single expression. +- `N`: The type of the nodes inside expressions. +- `E`: The type of the expressions. +- `TS`: The type of the named tuple containing those inner expressions. +- `D`: The type of the metadata, another named tuple. + +# Usage + For example, we can create two expressions, `f`, and `g`, and then combine them together in a new expression, `f_plus_g`, using a constructor function that simply adds them together: @@ -56,13 +73,12 @@ which will create a new method particular to this expression type defined on tha """ struct StructuredExpression{ T, - F, - EX<:NamedTuple, + F<:Function, N<:AbstractExpressionNode{T}, E<:AbstractExpression{T,N}, TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}}, - D<:@NamedTuple{structure::F, operators::O, variable_names::V, extra::EX} where {O,V}, -} <: AbstractExpression{T,N} + D<:@NamedTuple{structure::F, operators::O, variable_names::V} where {O,V}, +} <: AbstractStructuredExpression{T,F,N,E,D} trees::TS metadata::Metadata{D} @@ -70,15 +86,12 @@ struct StructuredExpression{ trees::TS, metadata::Metadata{D} ) where { TS, - F, - EX, - D<:@NamedTuple{ - structure::F, operators::O, variable_names::V, extra::EX - } where {O,V}, + F<:Function, + D<:@NamedTuple{structure::F, operators::O, variable_names::V} where {O,V}, } E = typeof(first(values(trees))) N = node_type(E) - return new{eltype(N),F,EX,N,E,TS,D}(trees, metadata) + return new{eltype(N),F,N,E,TS,D}(trees, metadata) end end @@ -87,65 +100,67 @@ function StructuredExpression( structure::F, operators::Union{AbstractOperatorEnum,Nothing}=nothing, variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, - extra..., ) where {F<:Function} example_tree = first(values(trees)) operators = get_operators(example_tree, operators) variable_names = get_variable_names(example_tree, variable_names) - metadata = (; structure, operators, variable_names, extra=(; extra...)) + metadata = (; structure, operators, variable_names) return StructuredExpression(trees, Metadata(metadata)) end - -function Base.copy(e::StructuredExpression) +constructorof(::Type{<:StructuredExpression}) = StructuredExpression +function Base.copy(e::AbstractStructuredExpression) ts = get_contents(e) meta = get_metadata(e) + meta_inner = _data(meta) copy_ts = NamedTuple{keys(ts)}(map(copy, values(ts))) - return StructuredExpression( - copy_ts, - Metadata((; - meta.structure, - operators=_copy(meta.operators), - variable_names=_copy(meta.variable_names), - extra=_copy(meta.extra), - )), + keys_except_structure = filter(!=(:structure), keys(meta_inner)) + copy_metadata = (; + meta_inner.structure, + NamedTuple{keys_except_structure}( + map(_copy, values(meta_inner[keys_except_structure])) + )..., ) + return constructorof(typeof(e))(copy_ts, Metadata(copy_metadata)) end -#! format: off -function get_contents(e::StructuredExpression) +function get_contents(e::AbstractStructuredExpression) return e.trees end -function get_metadata(e::StructuredExpression) +function get_metadata(e::AbstractStructuredExpression) return e.metadata end -function get_tree(e::StructuredExpression) - return get_tree(e.metadata.structure(e.trees)) +function get_tree(e::AbstractStructuredExpression) + return get_tree(get_metadata(e).structure(get_contents(e))) end -function get_operators(e::StructuredExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing) - return operators === nothing ? e.metadata.operators : operators +function get_operators( + e::AbstractStructuredExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) + return operators === nothing ? get_metadata(e).operators : operators end -function get_variable_names(e::StructuredExpression, variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing) - return variable_names === nothing ? e.metadata.variable_names : variable_names +function get_variable_names( + e::AbstractStructuredExpression, + variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, +) + return variable_names === nothing ? get_metadata(e).variable_names : variable_names end -function get_scalar_constants(e::StructuredExpression) +function get_scalar_constants(e::AbstractStructuredExpression) # Get constants for each inner expression - consts_and_refs = map(get_scalar_constants, values(e.trees)) + consts_and_refs = map(get_scalar_constants, values(get_contents(e))) flat_constants = vcat(map(first, consts_and_refs)...) # Collect info so we can put them back in the right place, # like the indexes of the constants in the flattened array refs = map(c_ref -> (; n=length(first(c_ref)), ref=last(c_ref)), consts_and_refs) return flat_constants, refs end -function set_scalar_constants!(e::StructuredExpression, constants, refs) +function set_scalar_constants!(e::AbstractStructuredExpression, constants, refs) cursor = Ref(1) - foreach(values(e.trees), refs) do tree, r + foreach(values(get_contents(e)), refs) do tree, r n = r.n i = cursor[] - c = constants[i:(i+n-1)] + c = constants[i:(i + n - 1)] set_scalar_constants!(tree, c, r.ref) cursor[] += n end return e end -#! format: on end diff --git a/test/test_structured_expression.jl b/test/test_structured_expression.jl index e256bad8..a5bc1dcc 100644 --- a/test/test_structured_expression.jl +++ b/test/test_structured_expression.jl @@ -63,14 +63,7 @@ end f = parse_expression(:(x * x - cos(2.5f0 * y + -0.5f0)); kws...) g = parse_expression(:(exp(-(y * y))); kws...) - c = [1] - ex = StructuredExpression((; f, g); structure=my_factory, a=c) - - @test ex.metadata.extra.a[] == 1 - @test ex.metadata.extra.a === c - - # Should copy everything down to the metadata: - @test copy(ex).metadata.extra.a !== c + ex = StructuredExpression((; f, g); structure=my_factory) h(_) = 1 h(::StructuredExpression{<:Any,typeof(my_factory)}) = 2