Skip to content

Commit 25a1112

Browse files
committed
feat: create in-place copy operator
1 parent 4a52c37 commit 25a1112

File tree

3 files changed

+90
-2
lines changed

3 files changed

+90
-2
lines changed

src/Interfaces.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ using ..NodeModule:
1313
default_allocator,
1414
with_type_parameters,
1515
leaf_copy,
16+
leaf_copy!,
1617
leaf_convert,
1718
leaf_hash,
1819
leaf_equal,
1920
branch_copy,
21+
branch_copy!,
2022
branch_convert,
2123
branch_hash,
2224
branch_equal,
@@ -262,6 +264,12 @@ function _check_leaf_copy(tree::AbstractExpressionNode)
262264
tree.degree != 0 && return true
263265
return leaf_copy(tree) isa typeof(tree)
264266
end
267+
function _check_leaf_copy!(tree::AbstractExpressionNode{T}) where {T}
268+
tree.degree != 0 && return true
269+
new_leaf = constructorof(typeof(tree))(; val=zero(T))
270+
ret = leaf_copy!(new_leaf, tree)
271+
return new_leaf == tree && ret === new_leaf
272+
end
265273
function _check_leaf_convert(tree::AbstractExpressionNode)
266274
tree.degree != 0 && return true
267275
return leaf_convert(typeof(tree), tree) isa typeof(tree) &&
@@ -284,6 +292,19 @@ function _check_branch_copy(tree::AbstractExpressionNode)
284292
return branch_copy(tree, tree.l, tree.r) isa typeof(tree)
285293
end
286294
end
295+
function _check_branch_copy!(tree::AbstractExpressionNode{T}) where {T}
296+
if tree.degree == 0
297+
return true
298+
end
299+
new_branch = constructorof(typeof(tree))(; val=zero(T))
300+
if tree.degree == 1
301+
ret = branch_copy!(new_branch, tree, copy(tree.l))
302+
return new_branch == tree && ret === new_branch
303+
else
304+
ret = branch_copy!(new_branch, tree, copy(tree.l), copy(tree.r))
305+
return new_branch == tree && ret === new_branch
306+
end
307+
end
287308
function _check_branch_convert(tree::AbstractExpressionNode)
288309
if tree.degree == 0
289310
return true
@@ -352,10 +373,12 @@ ni_components = (
352373
),
353374
optional = (
354375
leaf_copy = "copies a leaf node" => _check_leaf_copy,
376+
leaf_copy! = "copies a leaf node in-place" => _check_leaf_copy!,
355377
leaf_convert = "converts a leaf node" => _check_leaf_convert,
356378
leaf_hash = "computes the hash of a leaf node" => _check_leaf_hash,
357379
leaf_equal = "checks equality of two leaf nodes" => _check_leaf_equal,
358380
branch_copy = "copies a branch node" => _check_branch_copy,
381+
branch_copy! = "copies a branch node in-place" => _check_branch_copy!,
359382
branch_convert = "converts a branch node" => _check_branch_convert,
360383
branch_hash = "computes the hash of a branch node" => _check_branch_hash,
361384
branch_equal = "checks equality of two branch nodes" => _check_branch_equal,
@@ -396,7 +419,7 @@ ni_description = (
396419
[Arguments()]
397420
)
398421
@implements(
399-
NodeInterface{all_ni_methods_except(())},
422+
NodeInterface{all_ni_methods_except((:leaf_copy!, :branch_copy!))},
400423
GraphNode,
401424
[Arguments()]
402425
)

src/ParametricExpression.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ import ..NodeModule:
1313
with_type_parameters,
1414
preserve_sharing,
1515
leaf_copy,
16+
leaf_copy!,
1617
leaf_convert,
1718
leaf_hash,
18-
leaf_equal
19+
leaf_equal,
20+
branch_copy!
1921
import ..NodeUtilsModule:
2022
count_constant_nodes,
2123
index_constant_nodes,
@@ -122,6 +124,22 @@ function leaf_copy(t::ParametricNode{T}) where {T}
122124
return n
123125
end
124126
end
127+
function leaf_copy!(dest::N, src::N) where {T,N<:ParametricNode{T}}
128+
dest.degree = 0
129+
if src.constant
130+
dest.constant = true
131+
dest.val = src.val
132+
elseif !src.is_parameter
133+
dest.constant = false
134+
dest.is_parameter = false
135+
dest.feature = src.feature
136+
else
137+
dest.constant = false
138+
dest.is_parameter = true
139+
dest.parameter = src.parameter
140+
end
141+
return dest
142+
end
125143
function leaf_convert(::Type{N}, t::ParametricNode) where {T,N<:ParametricNode{T}}
126144
if t.constant
127145
return constructorof(N)(T; val=convert(T, t.val))

src/base.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,53 @@ function branch_copy(t::N, children::Vararg{Any,M}) where {T,N<:AbstractExpressi
485485
return constructorof(N)(T; op=t.op, children)
486486
end
487487

488+
# In-place versions
489+
490+
"""
491+
copy_node!(dest::AbstractArray{N}, src::N; break_sharing::Val{BS}=Val(false)) where {BS,N<:AbstractExpressionNode}
492+
493+
Copy a node, recursively copying all children nodes, in-place to an
494+
array of pre-allocated nodes. This should result in no extra allocations.
495+
"""
496+
function copy_node!(
497+
dest::AbstractArray{N},
498+
src::N;
499+
break_sharing::Val{BS}=Val(false),
500+
ref::Base.RefValue{<:Integer}=Ref(0),
501+
) where {BS,N<:AbstractExpressionNode}
502+
return tree_mapreduce(
503+
leaf -> leaf_copy!(@inbounds(dest[ref.x += 1]), leaf),
504+
identity,
505+
((p, c::Vararg{Any,M}) where {M}) ->
506+
branch_copy!(@inbounds(dest[ref.x += 1]), p, c...),
507+
src,
508+
N;
509+
break_sharing=Val(BS),
510+
)
511+
end
512+
function leaf_copy!(dest::N, src::N) where {T,N<:AbstractExpressionNode{T}}
513+
dest.degree = 0
514+
if src.constant
515+
dest.constant = true
516+
dest.val = src.val
517+
else
518+
dest.constant = false
519+
dest.feature = src.feature
520+
end
521+
return dest
522+
end
523+
function branch_copy!(
524+
dest::N, src::N, children::Vararg{N,M}
525+
) where {T,N<:AbstractExpressionNode{T},M}
526+
dest.degree = M
527+
dest.op = src.op
528+
dest.l = children[1]
529+
if M == 2
530+
dest.r = children[2]
531+
end
532+
return dest
533+
end
534+
488535
"""
489536
copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
490537

0 commit comments

Comments
 (0)