Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DynamicExpressions"
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
authors = ["MilesCranmer <[email protected]>"]
version = "1.1.0"
version = "1.2.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 1 addition & 0 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ import .ExpressionModule:
import .ParseModule: parse_leaf
@reexport import .ParametricExpressionModule: ParametricExpression, ParametricNode
@reexport import .StructuredExpressionModule: StructuredExpression
import .StructuredExpressionModule: AbstractStructuredExpression

@stable default_mode = "disable" begin
include("Interfaces.jl")
Expand Down
91 changes: 53 additions & 38 deletions src/StructuredExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
using ..ExpressionModule: AbstractExpression, Metadata, node_type
using ..ChainRulesModule: NodeTangent

import ..NodeModule: constructorof
import ..ExpressionModule:
get_contents,
get_metadata,
Expand All @@ -13,17 +14,33 @@ import ..ExpressionModule:
get_variable_names,
Metadata,
_copy,
_data,
default_node_type,
node_type,
get_scalar_constants,
set_scalar_constants!

abstract type AbstractStructuredExpression{
T,F<:Function,N<:AbstractExpressionNode{T},E<:AbstractExpression{T,N},D<:NamedTuple
} <: AbstractExpression{T,N} end

"""
StructuredExpression
StructuredExpression{T,F,N,E,TS,D} <: AbstractStructuredExpression{T,F,N,E,D} <: AbstractExpression{T,N}

This expression type allows you to combine multiple expressions
together in a predefined way.

# Parameters

- `T`: The numeric value type of the expressions.
- `F`: The type of the structure function, which combines each expression into a single expression.
- `N`: The type of the nodes inside expressions.
- `E`: The type of the expressions.
- `TS`: The type of the named tuple containing those inner expressions.
- `D`: The type of the metadata, another named tuple.

# Usage

For example, we can create two expressions, `f`, and `g`,
and then combine them together in a new expression, `f_plus_g`,
using a constructor function that simply adds them together:
Expand Down Expand Up @@ -56,29 +73,25 @@ which will create a new method particular to this expression type defined on tha
"""
struct StructuredExpression{
T,
F,
EX<:NamedTuple,
F<:Function,
N<:AbstractExpressionNode{T},
E<:AbstractExpression{T,N},
TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}},
D<:@NamedTuple{structure::F, operators::O, variable_names::V, extra::EX} where {O,V},
} <: AbstractExpression{T,N}
D<:@NamedTuple{structure::F, operators::O, variable_names::V} where {O,V},
} <: AbstractStructuredExpression{T,F,N,E,D}
trees::TS
metadata::Metadata{D}

function StructuredExpression(
trees::TS, metadata::Metadata{D}
) where {
TS,
F,
EX,
D<:@NamedTuple{
structure::F, operators::O, variable_names::V, extra::EX
} where {O,V},
F<:Function,
D<:@NamedTuple{structure::F, operators::O, variable_names::V} where {O,V},
}
E = typeof(first(values(trees)))
N = node_type(E)
return new{eltype(N),F,EX,N,E,TS,D}(trees, metadata)
return new{eltype(N),F,N,E,TS,D}(trees, metadata)
end
end

Expand All @@ -87,65 +100,67 @@ function StructuredExpression(
structure::F,
operators::Union{AbstractOperatorEnum,Nothing}=nothing,
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
extra...,
) where {F<:Function}
example_tree = first(values(trees))
operators = get_operators(example_tree, operators)
variable_names = get_variable_names(example_tree, variable_names)
metadata = (; structure, operators, variable_names, extra=(; extra...))
metadata = (; structure, operators, variable_names)
return StructuredExpression(trees, Metadata(metadata))
end

function Base.copy(e::StructuredExpression)
constructorof(::Type{<:StructuredExpression}) = StructuredExpression
function Base.copy(e::AbstractStructuredExpression)
ts = get_contents(e)
meta = get_metadata(e)
meta_inner = _data(meta)
copy_ts = NamedTuple{keys(ts)}(map(copy, values(ts)))
return StructuredExpression(
copy_ts,
Metadata((;
meta.structure,
operators=_copy(meta.operators),
variable_names=_copy(meta.variable_names),
extra=_copy(meta.extra),
)),
keys_except_structure = filter(!=(:structure), keys(meta_inner))
copy_metadata = (;
meta_inner.structure,
NamedTuple{keys_except_structure}(
map(_copy, values(meta_inner[keys_except_structure]))
)...,
)
return constructorof(typeof(e))(copy_ts, Metadata(copy_metadata))
end
#! format: off
function get_contents(e::StructuredExpression)
function get_contents(e::AbstractStructuredExpression)
return e.trees
end
function get_metadata(e::StructuredExpression)
function get_metadata(e::AbstractStructuredExpression)
return e.metadata
end
function get_tree(e::StructuredExpression)
return get_tree(e.metadata.structure(e.trees))
function get_tree(e::AbstractStructuredExpression)
return get_tree(get_metadata(e).structure(get_contents(e)))
end
function get_operators(e::StructuredExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing)
return operators === nothing ? e.metadata.operators : operators
function get_operators(
e::AbstractStructuredExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing
)
return operators === nothing ? get_metadata(e).operators : operators
end
function get_variable_names(e::StructuredExpression, variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing)
return variable_names === nothing ? e.metadata.variable_names : variable_names
function get_variable_names(
e::AbstractStructuredExpression,
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
)
return variable_names === nothing ? get_metadata(e).variable_names : variable_names
end
function get_scalar_constants(e::StructuredExpression)
function get_scalar_constants(e::AbstractStructuredExpression)
# Get constants for each inner expression
consts_and_refs = map(get_scalar_constants, values(e.trees))
consts_and_refs = map(get_scalar_constants, values(get_contents(e)))
flat_constants = vcat(map(first, consts_and_refs)...)
# Collect info so we can put them back in the right place,
# like the indexes of the constants in the flattened array
refs = map(c_ref -> (; n=length(first(c_ref)), ref=last(c_ref)), consts_and_refs)
return flat_constants, refs
end
function set_scalar_constants!(e::StructuredExpression, constants, refs)
function set_scalar_constants!(e::AbstractStructuredExpression, constants, refs)
cursor = Ref(1)
foreach(values(e.trees), refs) do tree, r
foreach(values(get_contents(e)), refs) do tree, r
n = r.n
i = cursor[]
c = constants[i:(i+n-1)]
c = constants[i:(i + n - 1)]
set_scalar_constants!(tree, c, r.ref)
cursor[] += n
end
return e
end
#! format: on

end
9 changes: 1 addition & 8 deletions test/test_structured_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,7 @@ end
f = parse_expression(:(x * x - cos(2.5f0 * y + -0.5f0)); kws...)
g = parse_expression(:(exp(-(y * y))); kws...)

c = [1]
ex = StructuredExpression((; f, g); structure=my_factory, a=c)

@test ex.metadata.extra.a[] == 1
@test ex.metadata.extra.a === c

# Should copy everything down to the metadata:
@test copy(ex).metadata.extra.a !== c
ex = StructuredExpression((; f, g); structure=my_factory)

h(_) = 1
h(::StructuredExpression{<:Any,typeof(my_factory)}) = 2
Expand Down
Loading