Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import .ValueInterfaceModule:
GraphNode,
Node,
copy_node,
copy_node!,
set_node!,
tree_mapreduce,
filter_map,
Expand Down
25 changes: 24 additions & 1 deletion src/Interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ using ..NodeModule:
default_allocator,
with_type_parameters,
leaf_copy,
leaf_copy!,
leaf_convert,
leaf_hash,
leaf_equal,
branch_copy,
branch_copy!,
branch_convert,
branch_hash,
branch_equal,
Expand Down Expand Up @@ -262,6 +264,12 @@ function _check_leaf_copy(tree::AbstractExpressionNode)
tree.degree != 0 && return true
return leaf_copy(tree) isa typeof(tree)
end
function _check_leaf_copy!(tree::AbstractExpressionNode{T}) where {T}
tree.degree != 0 && return true
new_leaf = constructorof(typeof(tree))(; val=zero(T))
ret = leaf_copy!(new_leaf, tree)
return new_leaf == tree && ret === new_leaf
end
function _check_leaf_convert(tree::AbstractExpressionNode)
tree.degree != 0 && return true
return leaf_convert(typeof(tree), tree) isa typeof(tree) &&
Expand All @@ -284,6 +292,19 @@ function _check_branch_copy(tree::AbstractExpressionNode)
return branch_copy(tree, tree.l, tree.r) isa typeof(tree)
end
end
function _check_branch_copy!(tree::AbstractExpressionNode{T}) where {T}
if tree.degree == 0
return true
end
new_branch = constructorof(typeof(tree))(; val=zero(T))
if tree.degree == 1
ret = branch_copy!(new_branch, tree, copy(tree.l))
return new_branch == tree && ret === new_branch
else
ret = branch_copy!(new_branch, tree, copy(tree.l), copy(tree.r))
return new_branch == tree && ret === new_branch
end
end
function _check_branch_convert(tree::AbstractExpressionNode)
if tree.degree == 0
return true
Expand Down Expand Up @@ -352,10 +373,12 @@ ni_components = (
),
optional = (
leaf_copy = "copies a leaf node" => _check_leaf_copy,
leaf_copy! = "copies a leaf node in-place" => _check_leaf_copy!,
leaf_convert = "converts a leaf node" => _check_leaf_convert,
leaf_hash = "computes the hash of a leaf node" => _check_leaf_hash,
leaf_equal = "checks equality of two leaf nodes" => _check_leaf_equal,
branch_copy = "copies a branch node" => _check_branch_copy,
branch_copy! = "copies a branch node in-place" => _check_branch_copy!,
branch_convert = "converts a branch node" => _check_branch_convert,
branch_hash = "computes the hash of a branch node" => _check_branch_hash,
branch_equal = "checks equality of two branch nodes" => _check_branch_equal,
Expand Down Expand Up @@ -396,7 +419,7 @@ ni_description = (
[Arguments()]
)
@implements(
NodeInterface{all_ni_methods_except(())},
NodeInterface{all_ni_methods_except((:leaf_copy!, :branch_copy!))},
GraphNode,
[Arguments()]
)
Expand Down
20 changes: 19 additions & 1 deletion src/ParametricExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ import ..NodeModule:
with_type_parameters,
preserve_sharing,
leaf_copy,
leaf_copy!,
leaf_convert,
leaf_hash,
leaf_equal
leaf_equal,
branch_copy!
import ..NodeUtilsModule:
count_constant_nodes,
index_constant_nodes,
Expand Down Expand Up @@ -122,6 +124,22 @@ function leaf_copy(t::ParametricNode{T}) where {T}
return n
end
end
function leaf_copy!(dest::N, src::N) where {T,N<:ParametricNode{T}}
dest.degree = 0
if src.constant
dest.constant = true
dest.val = src.val
elseif !src.is_parameter
dest.constant = false
dest.is_parameter = false
dest.feature = src.feature
else
dest.constant = false
dest.is_parameter = true
dest.parameter = src.parameter
end
return dest
end
function leaf_convert(::Type{N}, t::ParametricNode) where {T,N<:ParametricNode{T}}
if t.constant
return constructorof(N)(T; val=convert(T, t.val))
Expand Down
48 changes: 48 additions & 0 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,54 @@ function branch_copy(t::N, children::Vararg{Any,M}) where {T,N<:AbstractExpressi
return constructorof(N)(T; op=t.op, children)
end

# In-place versions

"""
copy_node!(dest::AbstractArray{N}, src::N; break_sharing::Val{BS}=Val(false)) where {BS,N<:AbstractExpressionNode}

Copy a node, recursively copying all children nodes, in-place to an
array of pre-allocated nodes. This should result in no extra allocations.
"""
function copy_node!(
dest::AbstractArray{N},
src::N;
break_sharing::Val{BS}=Val(false),
ref::Base.RefValue{<:Integer}=Ref(0),
) where {BS,N<:AbstractExpressionNode}
ref.x = 0
return tree_mapreduce(
leaf -> leaf_copy!(@inbounds(dest[ref.x += 1]), leaf),
identity,
((p, c::Vararg{Any,M}) where {M}) ->
branch_copy!(@inbounds(dest[ref.x += 1]), p, c...),
src,
N;
break_sharing=Val(BS),
)
end
function leaf_copy!(dest::N, src::N) where {T,N<:AbstractExpressionNode{T}}
dest.degree = 0
if src.constant
dest.constant = true
dest.val = src.val
else
dest.constant = false
dest.feature = src.feature
end
return dest
end
function branch_copy!(
dest::N, src::N, children::Vararg{N,M}
) where {T,N<:AbstractExpressionNode{T},M}
dest.degree = M
dest.op = src.op
dest.l = children[1]
if M == 2
dest.r = children[2]
end
return dest
end

"""
copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))

Expand Down
53 changes: 53 additions & 0 deletions test/test_copy_inplace.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
@testitem "copy_node! - random trees" begin
using DynamicExpressions
using DynamicExpressions: copy_node!
include("tree_gen_utils.jl")

operators = OperatorEnum(; binary_operators=[+, *, /], unary_operators=[sin, cos])

for size in [1, 2, 5, 10, 20], _ in 1:10, N in (Node, ParametricNode)
tree = gen_random_tree_fixed_size(size, operators, 5, Float64, N)
n_nodes = count_nodes(tree)
@test n_nodes == size # Verify gen_random_tree_fixed_size worked

# Make array larger than needed to test bounds:
dest_array = [N{Float64}() for _ in 1:(n_nodes + 10)]
orig_nodes = dest_array[(n_nodes + 1):end] # Save reference to unused nodes

ref = Ref(0)
result = copy_node!(dest_array, tree; ref)

@test ref[] == n_nodes # Increment once per node

# Should be the same tree:
@test result == tree
@test hash(result) == hash(tree)

# The root should be the last node in the destination array:
@test result === dest_array[n_nodes]

# Every node in the resultant tree should be from an allocated
# node in the destination array:
@test all(n -> any(n === x for x in dest_array[1:n_nodes]), result)

# There should be no aliasing:
@test Set(map(objectid, result)) == Set(map(objectid, dest_array[1:n_nodes]))
end
end

@testitem "copy_node! - leaf nodes" begin
using DynamicExpressions
using DynamicExpressions: copy_node!

leaf_constant = Node{Float64}(; val=1.0)
leaf_feature = Node{Float64}(; feature=1)

for leaf in [leaf_constant, leaf_feature]
dest_array = [Node{Float64}() for _ in 1:1]
ref = Ref(0)
result = copy_node!(dest_array, leaf; ref=ref)
@test ref[] == 1
@test result == leaf
@test result === dest_array[1]
end
end
1 change: 1 addition & 0 deletions test/unittest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ end
include("test_base.jl")
end
include("test_base_2.jl")
include("test_copy_inplace.jl")

@testitem "Test extra node fields" begin
include("test_extra_node_fields.jl")
Expand Down
Loading