Skip to content

Commit 5e6482d

Browse files
authored
Merge pull request #113 from SymbolicML/reduce-allocs
in-place copy for trees
2 parents 15d9efb + 2268896 commit 5e6482d

File tree

6 files changed

+146
-2
lines changed

6 files changed

+146
-2
lines changed

src/DynamicExpressions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ import .ValueInterfaceModule:
4141
GraphNode,
4242
Node,
4343
copy_node,
44+
copy_node!,
4445
set_node!,
4546
tree_mapreduce,
4647
filter_map,

src/Interfaces.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ using ..NodeModule:
1313
default_allocator,
1414
with_type_parameters,
1515
leaf_copy,
16+
leaf_copy!,
1617
leaf_convert,
1718
leaf_hash,
1819
leaf_equal,
1920
branch_copy,
21+
branch_copy!,
2022
branch_convert,
2123
branch_hash,
2224
branch_equal,
@@ -262,6 +264,12 @@ function _check_leaf_copy(tree::AbstractExpressionNode)
262264
tree.degree != 0 && return true
263265
return leaf_copy(tree) isa typeof(tree)
264266
end
267+
function _check_leaf_copy!(tree::AbstractExpressionNode{T}) where {T}
268+
tree.degree != 0 && return true
269+
new_leaf = constructorof(typeof(tree))(; val=zero(T))
270+
ret = leaf_copy!(new_leaf, tree)
271+
return new_leaf == tree && ret === new_leaf
272+
end
265273
function _check_leaf_convert(tree::AbstractExpressionNode)
266274
tree.degree != 0 && return true
267275
return leaf_convert(typeof(tree), tree) isa typeof(tree) &&
@@ -284,6 +292,19 @@ function _check_branch_copy(tree::AbstractExpressionNode)
284292
return branch_copy(tree, tree.l, tree.r) isa typeof(tree)
285293
end
286294
end
295+
function _check_branch_copy!(tree::AbstractExpressionNode{T}) where {T}
296+
if tree.degree == 0
297+
return true
298+
end
299+
new_branch = constructorof(typeof(tree))(; val=zero(T))
300+
if tree.degree == 1
301+
ret = branch_copy!(new_branch, tree, copy(tree.l))
302+
return new_branch == tree && ret === new_branch
303+
else
304+
ret = branch_copy!(new_branch, tree, copy(tree.l), copy(tree.r))
305+
return new_branch == tree && ret === new_branch
306+
end
307+
end
287308
function _check_branch_convert(tree::AbstractExpressionNode)
288309
if tree.degree == 0
289310
return true
@@ -352,10 +373,12 @@ ni_components = (
352373
),
353374
optional = (
354375
leaf_copy = "copies a leaf node" => _check_leaf_copy,
376+
leaf_copy! = "copies a leaf node in-place" => _check_leaf_copy!,
355377
leaf_convert = "converts a leaf node" => _check_leaf_convert,
356378
leaf_hash = "computes the hash of a leaf node" => _check_leaf_hash,
357379
leaf_equal = "checks equality of two leaf nodes" => _check_leaf_equal,
358380
branch_copy = "copies a branch node" => _check_branch_copy,
381+
branch_copy! = "copies a branch node in-place" => _check_branch_copy!,
359382
branch_convert = "converts a branch node" => _check_branch_convert,
360383
branch_hash = "computes the hash of a branch node" => _check_branch_hash,
361384
branch_equal = "checks equality of two branch nodes" => _check_branch_equal,
@@ -396,7 +419,7 @@ ni_description = (
396419
[Arguments()]
397420
)
398421
@implements(
399-
NodeInterface{all_ni_methods_except(())},
422+
NodeInterface{all_ni_methods_except((:leaf_copy!, :branch_copy!))},
400423
GraphNode,
401424
[Arguments()]
402425
)

src/ParametricExpression.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ import ..NodeModule:
1313
with_type_parameters,
1414
preserve_sharing,
1515
leaf_copy,
16+
leaf_copy!,
1617
leaf_convert,
1718
leaf_hash,
18-
leaf_equal
19+
leaf_equal,
20+
branch_copy!
1921
import ..NodeUtilsModule:
2022
count_constant_nodes,
2123
index_constant_nodes,
@@ -122,6 +124,22 @@ function leaf_copy(t::ParametricNode{T}) where {T}
122124
return n
123125
end
124126
end
127+
function leaf_copy!(dest::N, src::N) where {T,N<:ParametricNode{T}}
128+
dest.degree = 0
129+
if src.constant
130+
dest.constant = true
131+
dest.val = src.val
132+
elseif !src.is_parameter
133+
dest.constant = false
134+
dest.is_parameter = false
135+
dest.feature = src.feature
136+
else
137+
dest.constant = false
138+
dest.is_parameter = true
139+
dest.parameter = src.parameter
140+
end
141+
return dest
142+
end
125143
function leaf_convert(::Type{N}, t::ParametricNode) where {T,N<:ParametricNode{T}}
126144
if t.constant
127145
return constructorof(N)(T; val=convert(T, t.val))

src/base.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,54 @@ function branch_copy(t::N, children::Vararg{Any,M}) where {T,N<:AbstractExpressi
485485
return constructorof(N)(T; op=t.op, children)
486486
end
487487

488+
# In-place versions
489+
490+
"""
491+
copy_node!(dest::AbstractArray{N}, src::N; break_sharing::Val{BS}=Val(false)) where {BS,N<:AbstractExpressionNode}
492+
493+
Copy a node, recursively copying all children nodes, in-place to an
494+
array of pre-allocated nodes. This should result in no extra allocations.
495+
"""
496+
function copy_node!(
497+
dest::AbstractArray{N},
498+
src::N;
499+
break_sharing::Val{BS}=Val(false),
500+
ref::Base.RefValue{<:Integer}=Ref(0),
501+
) where {BS,N<:AbstractExpressionNode}
502+
ref.x = 0
503+
return tree_mapreduce(
504+
leaf -> leaf_copy!(@inbounds(dest[ref.x += 1]), leaf),
505+
identity,
506+
((p, c::Vararg{Any,M}) where {M}) ->
507+
branch_copy!(@inbounds(dest[ref.x += 1]), p, c...),
508+
src,
509+
N;
510+
break_sharing=Val(BS),
511+
)
512+
end
513+
function leaf_copy!(dest::N, src::N) where {T,N<:AbstractExpressionNode{T}}
514+
dest.degree = 0
515+
if src.constant
516+
dest.constant = true
517+
dest.val = src.val
518+
else
519+
dest.constant = false
520+
dest.feature = src.feature
521+
end
522+
return dest
523+
end
524+
function branch_copy!(
525+
dest::N, src::N, children::Vararg{N,M}
526+
) where {T,N<:AbstractExpressionNode{T},M}
527+
dest.degree = M
528+
dest.op = src.op
529+
dest.l = children[1]
530+
if M == 2
531+
dest.r = children[2]
532+
end
533+
return dest
534+
end
535+
488536
"""
489537
copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
490538

test/test_copy_inplace.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
@testitem "copy_node! - random trees" begin
2+
using DynamicExpressions
3+
using DynamicExpressions: copy_node!
4+
include("tree_gen_utils.jl")
5+
6+
operators = OperatorEnum(; binary_operators=[+, *, /], unary_operators=[sin, cos])
7+
8+
for size in [1, 2, 5, 10, 20], _ in 1:10, N in (Node, ParametricNode)
9+
tree = gen_random_tree_fixed_size(size, operators, 5, Float64, N)
10+
n_nodes = count_nodes(tree)
11+
@test n_nodes == size # Verify gen_random_tree_fixed_size worked
12+
13+
# Make array larger than needed to test bounds:
14+
dest_array = [N{Float64}() for _ in 1:(n_nodes + 10)]
15+
orig_nodes = dest_array[(n_nodes + 1):end] # Save reference to unused nodes
16+
17+
ref = Ref(0)
18+
result = copy_node!(dest_array, tree; ref)
19+
20+
@test ref[] == n_nodes # Increment once per node
21+
22+
# Should be the same tree:
23+
@test result == tree
24+
@test hash(result) == hash(tree)
25+
26+
# The root should be the last node in the destination array:
27+
@test result === dest_array[n_nodes]
28+
29+
# Every node in the resultant tree should be from an allocated
30+
# node in the destination array:
31+
@test all(n -> any(n === x for x in dest_array[1:n_nodes]), result)
32+
33+
# There should be no aliasing:
34+
@test Set(map(objectid, result)) == Set(map(objectid, dest_array[1:n_nodes]))
35+
end
36+
end
37+
38+
@testitem "copy_node! - leaf nodes" begin
39+
using DynamicExpressions
40+
using DynamicExpressions: copy_node!
41+
42+
leaf_constant = Node{Float64}(; val=1.0)
43+
leaf_feature = Node{Float64}(; feature=1)
44+
45+
for leaf in [leaf_constant, leaf_feature]
46+
dest_array = [Node{Float64}() for _ in 1:1]
47+
ref = Ref(0)
48+
result = copy_node!(dest_array, leaf; ref=ref)
49+
@test ref[] == 1
50+
@test result == leaf
51+
@test result === dest_array[1]
52+
end
53+
end

test/unittest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ end
101101
include("test_base.jl")
102102
end
103103
include("test_base_2.jl")
104+
include("test_copy_inplace.jl")
104105

105106
@testitem "Test extra node fields" begin
106107
include("test_extra_node_fields.jl")

0 commit comments

Comments
 (0)