|
1 | 1 | module EquationModule |
2 | 2 |
|
3 | 3 | import ..OperatorEnumModule: AbstractOperatorEnum |
4 | | -import ..UtilsModule: @generate_idmap, @use_idmap |
| 4 | +import ..UtilsModule: @memoize_on, @with_memoize |
5 | 5 |
|
6 | 6 | const DEFAULT_NODE_TYPE = Float32 |
7 | 7 |
|
@@ -62,51 +62,7 @@ mutable struct Node{T} |
62 | 62 | end |
63 | 63 | ################################################################################ |
64 | 64 |
|
65 | | -""" |
66 | | - convert(::Type{Node{T1}}, n::Node{T2}) where {T1,T2} |
67 | | -
|
68 | | -Convert a `Node{T2}` to a `Node{T1}`. |
69 | | -This will recursively convert all children nodes to `Node{T1}`, |
70 | | -using `convert(T1, tree.val)` at constant nodes. |
71 | | -
|
72 | | -# Arguments |
73 | | -- `::Type{Node{T1}}`: Type to convert to. |
74 | | -- `tree::Node{T2}`: Node to convert. |
75 | | -""" |
76 | | -function Base.convert( |
77 | | - ::Type{Node{T1}}, tree::Node{T2}; preserve_sharing::Bool=false |
78 | | -) where {T1,T2} |
79 | | - if T1 == T2 |
80 | | - return tree |
81 | | - end |
82 | | - if preserve_sharing |
83 | | - @use_idmap(_convert(Node{T1}, tree), IdDict{Node{T2},Node{T1}}()) |
84 | | - else |
85 | | - _convert(Node{T1}, tree) |
86 | | - end |
87 | | -end |
88 | | - |
89 | | -@generate_idmap tree function _convert(::Type{Node{T1}}, tree::Node{T2}) where {T1,T2} |
90 | | - if tree.degree == 0 |
91 | | - if tree.constant |
92 | | - val = tree.val::T2 |
93 | | - if !(T2 <: T1) |
94 | | - # e.g., we don't want to convert Float32 to Union{Float32,Vector{Float32}}! |
95 | | - val = convert(T1, val) |
96 | | - end |
97 | | - Node(T1, 0, tree.constant, val) |
98 | | - else |
99 | | - Node(T1, 0, tree.constant, nothing, tree.feature) |
100 | | - end |
101 | | - elseif tree.degree == 1 |
102 | | - l = _convert(Node{T1}, tree.l) |
103 | | - Node(1, tree.constant, nothing, tree.feature, tree.op, l) |
104 | | - else |
105 | | - l = _convert(Node{T1}, tree.l) |
106 | | - r = _convert(Node{T1}, tree.r) |
107 | | - Node(2, tree.constant, nothing, tree.feature, tree.op, l, r) |
108 | | - end |
109 | | -end |
| 65 | +include("base.jl") |
110 | 66 |
|
111 | 67 | """ |
112 | 68 | Node([::Type{T}]; val=nothing, feature::Int=nothing) where {T} |
@@ -224,45 +180,6 @@ function set_node!(tree::Node{T}, new_tree::Node{T}) where {T} |
224 | 180 | return nothing |
225 | 181 | end |
226 | 182 |
|
227 | | -""" |
228 | | - copy_node(tree::Node; preserve_sharing::Bool=false) |
229 | | -
|
230 | | -Copy a node, recursively copying all children nodes. |
231 | | -This is more efficient than the built-in copy. |
232 | | -With `preserve_sharing=true`, this will also |
233 | | -preserve linkage between a node and |
234 | | -multiple parents, whereas without, this would create |
235 | | -duplicate child node copies. |
236 | | -
|
237 | | -id_map is a map from `objectid(tree)` to `copy(tree)`. |
238 | | -We check against the map before making a new copy; otherwise |
239 | | -we can simply reference the existing copy. |
240 | | -[Thanks to Ted Hopp.](https://stackoverflow.com/questions/49285475/how-to-copy-a-full-non-binary-tree-including-loops) |
241 | | -
|
242 | | -Note that this will *not* preserve loops in graphs. |
243 | | -""" |
244 | | -function copy_node(tree::Node{T}; preserve_sharing::Bool=false)::Node{T} where {T} |
245 | | - if preserve_sharing |
246 | | - @use_idmap(_copy_node(tree), IdDict{Node{T},Node{T}}()) |
247 | | - else |
248 | | - _copy_node(tree) |
249 | | - end |
250 | | -end |
251 | | - |
252 | | -@generate_idmap tree function _copy_node(tree::Node{T})::Node{T} where {T} |
253 | | - if tree.degree == 0 |
254 | | - if tree.constant |
255 | | - Node(; val=copy(tree.val::T)) |
256 | | - else |
257 | | - Node(T; feature=copy(tree.feature)) |
258 | | - end |
259 | | - elseif tree.degree == 1 |
260 | | - Node(copy(tree.op), _copy_node(tree.l)) |
261 | | - else |
262 | | - Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r)) |
263 | | - end |
264 | | -end |
265 | | - |
266 | 183 | const OP_NAMES = Dict( |
267 | 184 | "safe_log" => "log", |
268 | 185 | "safe_log2" => "log2", |
@@ -363,51 +280,4 @@ function print_tree( |
363 | 280 | return println(string_tree(tree, operators; varMap=varMap)) |
364 | 281 | end |
365 | 282 |
|
366 | | -function Base.hash(tree::Node{T})::UInt where {T} |
367 | | - if tree.degree == 0 |
368 | | - if tree.constant |
369 | | - # tree.val used. |
370 | | - return hash((0, tree.val::T)) |
371 | | - else |
372 | | - # tree.feature used. |
373 | | - return hash((1, tree.feature)) |
374 | | - end |
375 | | - elseif tree.degree == 1 |
376 | | - return hash((1, tree.op, hash(tree.l))) |
377 | | - else |
378 | | - return hash((2, tree.op, hash(tree.l), hash(tree.r))) |
379 | | - end |
380 | | -end |
381 | | - |
382 | | -function is_equal(a::Node{T}, b::Node{T})::Bool where {T} |
383 | | - if a.degree == 0 |
384 | | - b.degree != 0 && return false |
385 | | - if a.constant |
386 | | - !(b.constant) && return false |
387 | | - return a.val::T == b.val::T |
388 | | - else |
389 | | - b.constant && return false |
390 | | - return a.feature == b.feature |
391 | | - end |
392 | | - elseif a.degree == 1 |
393 | | - b.degree != 1 && return false |
394 | | - a.op != b.op && return false |
395 | | - return is_equal(a.l, b.l) |
396 | | - else |
397 | | - b.degree != 2 && return false |
398 | | - a.op != b.op && return false |
399 | | - return is_equal(a.l, b.l) && is_equal(a.r, b.r) |
400 | | - end |
401 | | -end |
402 | | - |
403 | | -function Base.:(==)(a::Node{T}, b::Node{T})::Bool where {T} |
404 | | - return is_equal(a, b) |
405 | | -end |
406 | | - |
407 | | -function Base.:(==)(a::Node{T1}, b::Node{T2})::Bool where {T1,T2} |
408 | | - T = promote_type(T1, T2) |
409 | | - # TODO: Should also have preserve_sharing check... |
410 | | - return is_equal(convert(Node{T}, a), convert(Node{T}, b)) |
411 | | -end |
412 | | - |
413 | 283 | end |
0 commit comments