From 25a1112cd9430bdbaa36b0e07d25991e5139a596 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 11 Dec 2024 16:06:07 -0800 Subject: [PATCH 1/2] feat: create in-place copy operator --- src/Interfaces.jl | 25 +++++++++++++++++++- src/ParametricExpression.jl | 20 +++++++++++++++- src/base.jl | 47 +++++++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 2 deletions(-) diff --git a/src/Interfaces.jl b/src/Interfaces.jl index 2875ef65..b950ec97 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -13,10 +13,12 @@ using ..NodeModule: default_allocator, with_type_parameters, leaf_copy, + leaf_copy!, leaf_convert, leaf_hash, leaf_equal, branch_copy, + branch_copy!, branch_convert, branch_hash, branch_equal, @@ -262,6 +264,12 @@ function _check_leaf_copy(tree::AbstractExpressionNode) tree.degree != 0 && return true return leaf_copy(tree) isa typeof(tree) end +function _check_leaf_copy!(tree::AbstractExpressionNode{T}) where {T} + tree.degree != 0 && return true + new_leaf = constructorof(typeof(tree))(; val=zero(T)) + ret = leaf_copy!(new_leaf, tree) + return new_leaf == tree && ret === new_leaf +end function _check_leaf_convert(tree::AbstractExpressionNode) tree.degree != 0 && return true return leaf_convert(typeof(tree), tree) isa typeof(tree) && @@ -284,6 +292,19 @@ function _check_branch_copy(tree::AbstractExpressionNode) return branch_copy(tree, tree.l, tree.r) isa typeof(tree) end end +function _check_branch_copy!(tree::AbstractExpressionNode{T}) where {T} + if tree.degree == 0 + return true + end + new_branch = constructorof(typeof(tree))(; val=zero(T)) + if tree.degree == 1 + ret = branch_copy!(new_branch, tree, copy(tree.l)) + return new_branch == tree && ret === new_branch + else + ret = branch_copy!(new_branch, tree, copy(tree.l), copy(tree.r)) + return new_branch == tree && ret === new_branch + end +end function _check_branch_convert(tree::AbstractExpressionNode) if tree.degree == 0 return true @@ -352,10 +373,12 @@ ni_components = ( ), optional = ( leaf_copy = "copies a leaf node" => _check_leaf_copy, + leaf_copy! = "copies a leaf node in-place" => _check_leaf_copy!, leaf_convert = "converts a leaf node" => _check_leaf_convert, leaf_hash = "computes the hash of a leaf node" => _check_leaf_hash, leaf_equal = "checks equality of two leaf nodes" => _check_leaf_equal, branch_copy = "copies a branch node" => _check_branch_copy, + branch_copy! = "copies a branch node in-place" => _check_branch_copy!, branch_convert = "converts a branch node" => _check_branch_convert, branch_hash = "computes the hash of a branch node" => _check_branch_hash, branch_equal = "checks equality of two branch nodes" => _check_branch_equal, @@ -396,7 +419,7 @@ ni_description = ( [Arguments()] ) @implements( - NodeInterface{all_ni_methods_except(())}, + NodeInterface{all_ni_methods_except((:leaf_copy!, :branch_copy!))}, GraphNode, [Arguments()] ) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 7f0ae660..60f5fb41 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -13,9 +13,11 @@ import ..NodeModule: with_type_parameters, preserve_sharing, leaf_copy, + leaf_copy!, leaf_convert, leaf_hash, - leaf_equal + leaf_equal, + branch_copy! import ..NodeUtilsModule: count_constant_nodes, index_constant_nodes, @@ -122,6 +124,22 @@ function leaf_copy(t::ParametricNode{T}) where {T} return n end end +function leaf_copy!(dest::N, src::N) where {T,N<:ParametricNode{T}} + dest.degree = 0 + if src.constant + dest.constant = true + dest.val = src.val + elseif !src.is_parameter + dest.constant = false + dest.is_parameter = false + dest.feature = src.feature + else + dest.constant = false + dest.is_parameter = true + dest.parameter = src.parameter + end + return dest +end function leaf_convert(::Type{N}, t::ParametricNode) where {T,N<:ParametricNode{T}} if t.constant return constructorof(N)(T; val=convert(T, t.val)) diff --git a/src/base.jl b/src/base.jl index 3d29a404..75ef12b6 100644 --- a/src/base.jl +++ b/src/base.jl @@ -485,6 +485,53 @@ function branch_copy(t::N, children::Vararg{Any,M}) where {T,N<:AbstractExpressi return constructorof(N)(T; op=t.op, children) end +# In-place versions + +""" + copy_node!(dest::AbstractArray{N}, src::N; break_sharing::Val{BS}=Val(false)) where {BS,N<:AbstractExpressionNode} + +Copy a node, recursively copying all children nodes, in-place to an +array of pre-allocated nodes. This should result in no extra allocations. +""" +function copy_node!( + dest::AbstractArray{N}, + src::N; + break_sharing::Val{BS}=Val(false), + ref::Base.RefValue{<:Integer}=Ref(0), +) where {BS,N<:AbstractExpressionNode} + return tree_mapreduce( + leaf -> leaf_copy!(@inbounds(dest[ref.x += 1]), leaf), + identity, + ((p, c::Vararg{Any,M}) where {M}) -> + branch_copy!(@inbounds(dest[ref.x += 1]), p, c...), + src, + N; + break_sharing=Val(BS), + ) +end +function leaf_copy!(dest::N, src::N) where {T,N<:AbstractExpressionNode{T}} + dest.degree = 0 + if src.constant + dest.constant = true + dest.val = src.val + else + dest.constant = false + dest.feature = src.feature + end + return dest +end +function branch_copy!( + dest::N, src::N, children::Vararg{N,M} +) where {T,N<:AbstractExpressionNode{T},M} + dest.degree = M + dest.op = src.op + dest.l = children[1] + if M == 2 + dest.r = children[2] + end + return dest +end + """ copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false)) From 22688966ea334f985c541fd57788bbc6719d4118 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 11 Dec 2024 17:38:57 -0800 Subject: [PATCH 2/2] test: add unittests for `copy_node!` --- src/DynamicExpressions.jl | 1 + src/base.jl | 1 + test/test_copy_inplace.jl | 53 +++++++++++++++++++++++++++++++++++++++ test/unittest.jl | 1 + 4 files changed, 56 insertions(+) create mode 100644 test/test_copy_inplace.jl diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 3259c856..8e32d899 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -41,6 +41,7 @@ import .ValueInterfaceModule: GraphNode, Node, copy_node, + copy_node!, set_node!, tree_mapreduce, filter_map, diff --git a/src/base.jl b/src/base.jl index 75ef12b6..8a656ab3 100644 --- a/src/base.jl +++ b/src/base.jl @@ -499,6 +499,7 @@ function copy_node!( break_sharing::Val{BS}=Val(false), ref::Base.RefValue{<:Integer}=Ref(0), ) where {BS,N<:AbstractExpressionNode} + ref.x = 0 return tree_mapreduce( leaf -> leaf_copy!(@inbounds(dest[ref.x += 1]), leaf), identity, diff --git a/test/test_copy_inplace.jl b/test/test_copy_inplace.jl new file mode 100644 index 00000000..4b337aec --- /dev/null +++ b/test/test_copy_inplace.jl @@ -0,0 +1,53 @@ +@testitem "copy_node! - random trees" begin + using DynamicExpressions + using DynamicExpressions: copy_node! + include("tree_gen_utils.jl") + + operators = OperatorEnum(; binary_operators=[+, *, /], unary_operators=[sin, cos]) + + for size in [1, 2, 5, 10, 20], _ in 1:10, N in (Node, ParametricNode) + tree = gen_random_tree_fixed_size(size, operators, 5, Float64, N) + n_nodes = count_nodes(tree) + @test n_nodes == size # Verify gen_random_tree_fixed_size worked + + # Make array larger than needed to test bounds: + dest_array = [N{Float64}() for _ in 1:(n_nodes + 10)] + orig_nodes = dest_array[(n_nodes + 1):end] # Save reference to unused nodes + + ref = Ref(0) + result = copy_node!(dest_array, tree; ref) + + @test ref[] == n_nodes # Increment once per node + + # Should be the same tree: + @test result == tree + @test hash(result) == hash(tree) + + # The root should be the last node in the destination array: + @test result === dest_array[n_nodes] + + # Every node in the resultant tree should be from an allocated + # node in the destination array: + @test all(n -> any(n === x for x in dest_array[1:n_nodes]), result) + + # There should be no aliasing: + @test Set(map(objectid, result)) == Set(map(objectid, dest_array[1:n_nodes])) + end +end + +@testitem "copy_node! - leaf nodes" begin + using DynamicExpressions + using DynamicExpressions: copy_node! + + leaf_constant = Node{Float64}(; val=1.0) + leaf_feature = Node{Float64}(; feature=1) + + for leaf in [leaf_constant, leaf_feature] + dest_array = [Node{Float64}() for _ in 1:1] + ref = Ref(0) + result = copy_node!(dest_array, leaf; ref=ref) + @test ref[] == 1 + @test result == leaf + @test result === dest_array[1] + end +end diff --git a/test/unittest.jl b/test/unittest.jl index b96ae144..42ae11bb 100644 --- a/test/unittest.jl +++ b/test/unittest.jl @@ -101,6 +101,7 @@ end include("test_base.jl") end include("test_base_2.jl") +include("test_copy_inplace.jl") @testitem "Test extra node fields" begin include("test_extra_node_fields.jl")