@@ -5,6 +5,7 @@ using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
55using .. ExpressionModule: AbstractExpression, Metadata, node_type
66using .. ChainRulesModule: NodeTangent
77
8+ import .. NodeModule: constructorof
89import .. ExpressionModule:
910 get_contents,
1011 get_metadata,
@@ -13,17 +14,33 @@ import ..ExpressionModule:
1314 get_variable_names,
1415 Metadata,
1516 _copy,
17+ _data,
1618 default_node_type,
1719 node_type,
1820 get_scalar_constants,
1921 set_scalar_constants!
2022
23+ abstract type AbstractStructuredExpression{
24+ T,F<: Function ,N<: AbstractExpressionNode{T} ,E<: AbstractExpression{T,N} ,D<: NamedTuple
25+ } <: AbstractExpression{T,N} end
26+
2127"""
22- StructuredExpression
28+ StructuredExpression{T,F,N,E,TS,D} <: AbstractStructuredExpression{T,F,N,E,D} <: AbstractExpression{T,N}
2329
2430This expression type allows you to combine multiple expressions
2531together in a predefined way.
2632
33+ # Parameters
34+
35+ - `T`: The numeric value type of the expressions.
36+ - `F`: The type of the structure function, which combines each expression into a single expression.
37+ - `N`: The type of the nodes inside expressions.
38+ - `E`: The type of the expressions.
39+ - `TS`: The type of the named tuple containing those inner expressions.
40+ - `D`: The type of the metadata, another named tuple.
41+
42+ # Usage
43+
2744For example, we can create two expressions, `f`, and `g`,
2845and then combine them together in a new expression, `f_plus_g`,
2946using a constructor function that simply adds them together:
@@ -56,29 +73,25 @@ which will create a new method particular to this expression type defined on tha
5673"""
5774struct StructuredExpression{
5875 T,
59- F,
60- EX<: NamedTuple ,
76+ F<: Function ,
6177 N<: AbstractExpressionNode{T} ,
6278 E<: AbstractExpression{T,N} ,
6379 TS<: NamedTuple{<:Any,<:NTuple{<:Any,E}} ,
64- D< :@NamedTuple {structure:: F , operators:: O , variable_names:: V , extra :: EX } where {O,V},
65- } <: AbstractExpression {T,N }
80+ D< :@NamedTuple {structure:: F , operators:: O , variable_names:: V } where {O,V},
81+ } <: AbstractStructuredExpression {T,F,N,E,D }
6682 trees:: TS
6783 metadata:: Metadata{D}
6884
6985 function StructuredExpression (
7086 trees:: TS , metadata:: Metadata{D}
7187 ) where {
7288 TS,
73- F,
74- EX,
75- D< :@NamedTuple {
76- structure:: F , operators:: O , variable_names:: V , extra:: EX
77- } where {O,V},
89+ F<: Function ,
90+ D< :@NamedTuple {structure:: F , operators:: O , variable_names:: V } where {O,V},
7891 }
7992 E = typeof (first (values (trees)))
8093 N = node_type (E)
81- return new {eltype(N),F,EX, N,E,TS,D} (trees, metadata)
94+ return new {eltype(N),F,N,E,TS,D} (trees, metadata)
8295 end
8396end
8497
@@ -87,65 +100,67 @@ function StructuredExpression(
87100 structure:: F ,
88101 operators:: Union{AbstractOperatorEnum,Nothing} = nothing ,
89102 variable_names:: Union{AbstractVector{<:AbstractString},Nothing} = nothing ,
90- extra... ,
91103) where {F<: Function }
92104 example_tree = first (values (trees))
93105 operators = get_operators (example_tree, operators)
94106 variable_names = get_variable_names (example_tree, variable_names)
95- metadata = (; structure, operators, variable_names, extra = (; extra ... ) )
107+ metadata = (; structure, operators, variable_names)
96108 return StructuredExpression (trees, Metadata (metadata))
97109end
98-
99- function Base. copy (e:: StructuredExpression )
110+ constructorof ( :: Type{<:StructuredExpression} ) = StructuredExpression
111+ function Base. copy (e:: AbstractStructuredExpression )
100112 ts = get_contents (e)
101113 meta = get_metadata (e)
114+ meta_inner = _data (meta)
102115 copy_ts = NamedTuple {keys(ts)} (map (copy, values (ts)))
103- return StructuredExpression (
104- copy_ts,
105- Metadata ((;
106- meta. structure,
107- operators= _copy (meta. operators),
108- variable_names= _copy (meta. variable_names),
109- extra= _copy (meta. extra),
110- )),
116+ keys_except_structure = filter (!= (:structure ), keys (meta_inner))
117+ copy_metadata = (;
118+ meta_inner. structure,
119+ NamedTuple {keys_except_structure} (
120+ map (_copy, values (meta_inner[keys_except_structure]))
121+ )... ,
111122 )
123+ return constructorof (typeof (e))(copy_ts, Metadata (copy_metadata))
112124end
113- # ! format: off
114- function get_contents (e:: StructuredExpression )
125+ function get_contents (e:: AbstractStructuredExpression )
115126 return e. trees
116127end
117- function get_metadata (e:: StructuredExpression )
128+ function get_metadata (e:: AbstractStructuredExpression )
118129 return e. metadata
119130end
120- function get_tree (e:: StructuredExpression )
121- return get_tree (e . metadata . structure (e . trees ))
131+ function get_tree (e:: AbstractStructuredExpression )
132+ return get_tree (get_metadata (e) . structure (get_contents (e) ))
122133end
123- function get_operators (e:: StructuredExpression , operators:: Union{AbstractOperatorEnum,Nothing} = nothing )
124- return operators === nothing ? e. metadata. operators : operators
134+ function get_operators (
135+ e:: AbstractStructuredExpression , operators:: Union{AbstractOperatorEnum,Nothing} = nothing
136+ )
137+ return operators === nothing ? get_metadata (e). operators : operators
125138end
126- function get_variable_names (e:: StructuredExpression , variable_names:: Union{AbstractVector{<:AbstractString},Nothing} = nothing )
127- return variable_names === nothing ? e. metadata. variable_names : variable_names
139+ function get_variable_names (
140+ e:: AbstractStructuredExpression ,
141+ variable_names:: Union{AbstractVector{<:AbstractString},Nothing} = nothing ,
142+ )
143+ return variable_names === nothing ? get_metadata (e). variable_names : variable_names
128144end
129- function get_scalar_constants (e:: StructuredExpression )
145+ function get_scalar_constants (e:: AbstractStructuredExpression )
130146 # Get constants for each inner expression
131- consts_and_refs = map (get_scalar_constants, values (e . trees ))
147+ consts_and_refs = map (get_scalar_constants, values (get_contents (e) ))
132148 flat_constants = vcat (map (first, consts_and_refs)... )
133149 # Collect info so we can put them back in the right place,
134150 # like the indexes of the constants in the flattened array
135151 refs = map (c_ref -> (; n= length (first (c_ref)), ref= last (c_ref)), consts_and_refs)
136152 return flat_constants, refs
137153end
138- function set_scalar_constants! (e:: StructuredExpression , constants, refs)
154+ function set_scalar_constants! (e:: AbstractStructuredExpression , constants, refs)
139155 cursor = Ref (1 )
140- foreach (values (e . trees ), refs) do tree, r
156+ foreach (values (get_contents (e) ), refs) do tree, r
141157 n = r. n
142158 i = cursor[]
143- c = constants[i: (i+ n - 1 )]
159+ c = constants[i: (i + n - 1 )]
144160 set_scalar_constants! (tree, c, r. ref)
145161 cursor[] += n
146162 end
147163 return e
148164end
149- # ! format: on
150165
151166end
0 commit comments