Skip to content

Commit 96d6609

Browse files
authored
Merge pull request #27 from SymbolicML/tree-map
Define functions in Base to treat `Node` as collection
2 parents 571d8b8 + 5073e35 commit 96d6609

15 files changed

+614
-283
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <[email protected]>"]
4-
version = "0.7.0"
4+
version = "0.8.0"
55

66
[deps]
7+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
910
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
11+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1012
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1113
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1214
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
13-
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1415
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1516
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
1617
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1718

1819
[compat]
20+
Compat = "3.37, 4"
1921
LoopVectorization = "0.12"
2022
MacroTools = "0.4, 0.5"
21-
Reexport = "1"
2223
PrecompileTools = "1"
24+
Reexport = "1"
2325
SymbolicUtils = "0.19, ^1.0.5"
2426
Zygote = "0.6"
2527
julia = "1.6"

benchmark/benchmarks.jl

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using DynamicExpressions, BenchmarkTools, Random
2-
using DynamicExpressions: copy_node
2+
using DynamicExpressions.EquationUtilsModule: is_constant
33

44
include("benchmark_utils.jl")
55

@@ -73,30 +73,54 @@ end
7373
PACKAGE_VERSION < v"0.7.0" && return :(copy_node(t; preserve_topology=preserve_sharing))
7474
return :(copy_node(t; preserve_sharing=preserve_sharing))
7575
end
76+
@generated function get_set_constants!(tree)
77+
!(@isdefined set_constants!) && return :(set_constants(tree, get_constants(tree)))
78+
return :(set_constants!(tree, get_constants(tree)))
79+
end
7680
#! format: on
7781

82+
f_tree_op(f::F, tree, operators) where {F} = f(tree, operators)
83+
f_tree_op(f::F, tree) where {F} = f(tree)
84+
7885
function benchmark_utilities()
7986
suite = BenchmarkGroup()
87+
88+
all_funcs = (
89+
:copy,
90+
:convert,
91+
:simplify_tree,
92+
:combine_operators,
93+
:count_nodes,
94+
:count_depth,
95+
:count_constants,
96+
:has_constants,
97+
:has_operators,
98+
:is_constant,
99+
:get_set_constants!,
100+
:index_constants,
101+
)
102+
80103
operators = OperatorEnum(; binary_operators=[+, -, /, *], unary_operators=[cos, exp])
81-
for func_k in ("copy", "convert", "simplify_tree", "combine_operators")
104+
105+
for func_k in all_funcs
82106
suite[func_k] = let s = BenchmarkGroup()
83-
for k in ("break_sharing", "preserve_sharing")
84-
k == "preserve_sharing" &&
85-
func_k in ("simplify_tree", "combine_operators") &&
86-
continue
107+
for k in (:break_sharing, :preserve_sharing)
108+
k == :preserve_sharing && !(func_k in (:copy, :convert)) && continue
87109

88-
f = if func_k == "copy"
89-
tree -> _copy_node(tree; preserve_sharing=(k == "preserve_sharing"))
90-
elseif func_k == "convert"
110+
f = if func_k == :copy
111+
tree -> _copy_node(tree; preserve_sharing=(k == :preserve_sharing))
112+
elseif func_k == :convert
91113
tree -> _convert(
92114
Node{Float64},
93115
tree;
94-
preserve_sharing=(k == "preserve_sharing"),
116+
preserve_sharing=(k == :preserve_sharing),
95117
)
96-
elseif func_k == "simplify_tree"
97-
tree -> simplify_tree(tree, operators)
98-
elseif func_k == "combine_operators"
99-
tree -> combine_operators(tree, operators)
118+
elseif func_k in (:simplify_tree, :combine_operators)
119+
g = getfield(@__MODULE__, func_k)
120+
tree -> f_tree_op(g, tree, operators)
121+
else
122+
g = getfield(@__MODULE__, func_k)
123+
tree -> f_tree_op(g, tree)
100124
end
101125

102126
#! format: off

src/DynamicExpressions.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,18 @@ include("SimplifyEquation.jl")
1212
include("OperatorEnumConstruction.jl")
1313

1414
using Reexport
15-
@reexport import .EquationModule: Node, string_tree, print_tree, copy_node, set_node!
15+
@reexport import .EquationModule:
16+
Node, string_tree, print_tree, copy_node, set_node!, tree_mapreduce, filter_map
1617
@reexport import .EquationUtilsModule:
1718
count_nodes,
18-
count_nodes_with_stack,
1919
count_constants,
2020
count_depth,
2121
NodeIndex,
2222
index_constants,
2323
has_operators,
2424
has_constants,
2525
get_constants,
26-
set_constants
26+
set_constants!
2727
@reexport import .OperatorEnumModule: AbstractOperatorEnum
2828
@reexport import .OperatorEnumConstructionModule:
2929
OperatorEnum, GenericOperatorEnum, @extend_operators
@@ -34,6 +34,8 @@ using Reexport
3434
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree
3535
@reexport import .EvaluationHelpersModule
3636

37+
include("deprecated.jl")
38+
3739
import TOML: parsefile
3840

3941
const PACKAGE_VERSION = let

src/Equation.jl

Lines changed: 2 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module EquationModule
22

33
import ..OperatorEnumModule: AbstractOperatorEnum
4-
import ..UtilsModule: @generate_idmap, @use_idmap
4+
import ..UtilsModule: @memoize_on, @with_memoize
55

66
const DEFAULT_NODE_TYPE = Float32
77

@@ -62,51 +62,7 @@ mutable struct Node{T}
6262
end
6363
################################################################################
6464

65-
"""
66-
convert(::Type{Node{T1}}, n::Node{T2}) where {T1,T2}
67-
68-
Convert a `Node{T2}` to a `Node{T1}`.
69-
This will recursively convert all children nodes to `Node{T1}`,
70-
using `convert(T1, tree.val)` at constant nodes.
71-
72-
# Arguments
73-
- `::Type{Node{T1}}`: Type to convert to.
74-
- `tree::Node{T2}`: Node to convert.
75-
"""
76-
function Base.convert(
77-
::Type{Node{T1}}, tree::Node{T2}; preserve_sharing::Bool=false
78-
) where {T1,T2}
79-
if T1 == T2
80-
return tree
81-
end
82-
if preserve_sharing
83-
@use_idmap(_convert(Node{T1}, tree), IdDict{Node{T2},Node{T1}}())
84-
else
85-
_convert(Node{T1}, tree)
86-
end
87-
end
88-
89-
@generate_idmap tree function _convert(::Type{Node{T1}}, tree::Node{T2}) where {T1,T2}
90-
if tree.degree == 0
91-
if tree.constant
92-
val = tree.val::T2
93-
if !(T2 <: T1)
94-
# e.g., we don't want to convert Float32 to Union{Float32,Vector{Float32}}!
95-
val = convert(T1, val)
96-
end
97-
Node(T1, 0, tree.constant, val)
98-
else
99-
Node(T1, 0, tree.constant, nothing, tree.feature)
100-
end
101-
elseif tree.degree == 1
102-
l = _convert(Node{T1}, tree.l)
103-
Node(1, tree.constant, nothing, tree.feature, tree.op, l)
104-
else
105-
l = _convert(Node{T1}, tree.l)
106-
r = _convert(Node{T1}, tree.r)
107-
Node(2, tree.constant, nothing, tree.feature, tree.op, l, r)
108-
end
109-
end
65+
include("base.jl")
11066

11167
"""
11268
Node([::Type{T}]; val=nothing, feature::Int=nothing) where {T}
@@ -224,45 +180,6 @@ function set_node!(tree::Node{T}, new_tree::Node{T}) where {T}
224180
return nothing
225181
end
226182

227-
"""
228-
copy_node(tree::Node; preserve_sharing::Bool=false)
229-
230-
Copy a node, recursively copying all children nodes.
231-
This is more efficient than the built-in copy.
232-
With `preserve_sharing=true`, this will also
233-
preserve linkage between a node and
234-
multiple parents, whereas without, this would create
235-
duplicate child node copies.
236-
237-
id_map is a map from `objectid(tree)` to `copy(tree)`.
238-
We check against the map before making a new copy; otherwise
239-
we can simply reference the existing copy.
240-
[Thanks to Ted Hopp.](https://stackoverflow.com/questions/49285475/how-to-copy-a-full-non-binary-tree-including-loops)
241-
242-
Note that this will *not* preserve loops in graphs.
243-
"""
244-
function copy_node(tree::Node{T}; preserve_sharing::Bool=false)::Node{T} where {T}
245-
if preserve_sharing
246-
@use_idmap(_copy_node(tree), IdDict{Node{T},Node{T}}())
247-
else
248-
_copy_node(tree)
249-
end
250-
end
251-
252-
@generate_idmap tree function _copy_node(tree::Node{T})::Node{T} where {T}
253-
if tree.degree == 0
254-
if tree.constant
255-
Node(; val=copy(tree.val::T))
256-
else
257-
Node(T; feature=copy(tree.feature))
258-
end
259-
elseif tree.degree == 1
260-
Node(copy(tree.op), _copy_node(tree.l))
261-
else
262-
Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r))
263-
end
264-
end
265-
266183
const OP_NAMES = Dict(
267184
"safe_log" => "log",
268185
"safe_log2" => "log2",
@@ -363,51 +280,4 @@ function print_tree(
363280
return println(string_tree(tree, operators; varMap=varMap))
364281
end
365282

366-
function Base.hash(tree::Node{T})::UInt where {T}
367-
if tree.degree == 0
368-
if tree.constant
369-
# tree.val used.
370-
return hash((0, tree.val::T))
371-
else
372-
# tree.feature used.
373-
return hash((1, tree.feature))
374-
end
375-
elseif tree.degree == 1
376-
return hash((1, tree.op, hash(tree.l)))
377-
else
378-
return hash((2, tree.op, hash(tree.l), hash(tree.r)))
379-
end
380-
end
381-
382-
function is_equal(a::Node{T}, b::Node{T})::Bool where {T}
383-
if a.degree == 0
384-
b.degree != 0 && return false
385-
if a.constant
386-
!(b.constant) && return false
387-
return a.val::T == b.val::T
388-
else
389-
b.constant && return false
390-
return a.feature == b.feature
391-
end
392-
elseif a.degree == 1
393-
b.degree != 1 && return false
394-
a.op != b.op && return false
395-
return is_equal(a.l, b.l)
396-
else
397-
b.degree != 2 && return false
398-
a.op != b.op && return false
399-
return is_equal(a.l, b.l) && is_equal(a.r, b.r)
400-
end
401-
end
402-
403-
function Base.:(==)(a::Node{T}, b::Node{T})::Bool where {T}
404-
return is_equal(a, b)
405-
end
406-
407-
function Base.:(==)(a::Node{T1}, b::Node{T2})::Bool where {T1,T2}
408-
T = promote_type(T1, T2)
409-
# TODO: Should also have preserve_sharing check...
410-
return is_equal(convert(Node{T}, a), convert(Node{T}, b))
411-
end
412-
413283
end

0 commit comments

Comments
 (0)