Skip to content

Commit 1f97a0b

Browse files
committed
feat: rename to allocate_container and copy_into!
1 parent ed086e7 commit 1f97a0b

File tree

7 files changed

+110
-108
lines changed

7 files changed

+110
-108
lines changed

src/DynamicExpressions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using DispatchDoctor: @stable, @unstable
99
include("OperatorEnum.jl")
1010
include("Node.jl")
1111
include("NodeUtils.jl")
12+
include("NodePreallocation.jl")
1213
include("Strings.jl")
1314
include("Evaluate.jl")
1415
include("EvaluateDerivative.jl")
@@ -41,11 +42,11 @@ import .ValueInterfaceModule:
4142
GraphNode,
4243
Node,
4344
copy_node,
44-
copy_node!,
4545
set_node!,
4646
tree_mapreduce,
4747
filter_map,
4848
filter_map!
49+
import .NodePreallocationModule: allocate_container, copy_into!
4950
import .NodeModule:
5051
constructorof,
5152
with_type_parameters,

src/Expression.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -317,14 +317,6 @@ function extract_gradient(
317317
return extract_gradient(gradient.tree, get_tree(ex))
318318
end
319319

320-
function preallocate_expression(prototype::Expression, n::Union{Nothing,Integer}=nothing)
321-
return (; tree=preallocate_expression(DE.get_contents(prototype), n))
322-
end
323-
function DE.copy_node!(dest::NamedTuple, src::Expression)
324-
tree = DE.copy_node!(dest.tree, DE.get_contents(src))
325-
return DE.with_contents(src, tree)
326-
end
327-
328320
"""
329321
string_tree(
330322
ex::AbstractExpression,
@@ -510,4 +502,21 @@ function (ex::AbstractExpression)(
510502
return get_tree(ex)(X, get_operators(ex, operators); kws...)
511503
end
512504

505+
# We don't require users to overload this, as it's not part of the required interface.
506+
# Also, there's no way to generally do this from the required interface, so for backwards
507+
# compatibility, we just return nothing.
508+
function copy_into!(::Nothing, src::AbstractExpression)
509+
return copy(src)
510+
end
511+
function allocate_container(::AbstractExpression, ::Union{Nothing,Integer}=nothing)
512+
return nothing
513+
end
514+
function allocate_container(prototype::Expression, n::Union{Nothing,Integer}=nothing)
515+
return (; tree=allocate_container(get_contents(prototype), n))
516+
end
517+
function copy_into!(dest::NamedTuple, src::Expression)
518+
tree = copy_into!(dest.tree, get_contents(src))
519+
return with_contents(src, tree)
520+
end
521+
513522
end

src/Interfaces.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ using ..NodeModule:
1313
default_allocator,
1414
with_type_parameters,
1515
leaf_copy,
16-
leaf_copy!,
16+
leaf_copy_into!,
1717
leaf_convert,
1818
leaf_hash,
1919
leaf_equal,
2020
branch_copy,
21-
branch_copy!,
21+
branch_copy_into!,
2222
branch_convert,
2323
branch_hash,
2424
branch_equal,
@@ -264,10 +264,10 @@ function _check_leaf_copy(tree::AbstractExpressionNode)
264264
tree.degree != 0 && return true
265265
return leaf_copy(tree) isa typeof(tree)
266266
end
267-
function _check_leaf_copy!(tree::AbstractExpressionNode{T}) where {T}
267+
function _check_leaf_copy_into!(tree::AbstractExpressionNode{T}) where {T}
268268
tree.degree != 0 && return true
269269
new_leaf = constructorof(typeof(tree))(; val=zero(T))
270-
ret = leaf_copy!(new_leaf, tree)
270+
ret = leaf_copy_into!(new_leaf, tree)
271271
return new_leaf == tree && ret === new_leaf
272272
end
273273
function _check_leaf_convert(tree::AbstractExpressionNode)
@@ -292,16 +292,16 @@ function _check_branch_copy(tree::AbstractExpressionNode)
292292
return branch_copy(tree, tree.l, tree.r) isa typeof(tree)
293293
end
294294
end
295-
function _check_branch_copy!(tree::AbstractExpressionNode{T}) where {T}
295+
function _check_branch_copy_into!(tree::AbstractExpressionNode{T}) where {T}
296296
if tree.degree == 0
297297
return true
298298
end
299299
new_branch = constructorof(typeof(tree))(; val=zero(T))
300300
if tree.degree == 1
301-
ret = branch_copy!(new_branch, tree, copy(tree.l))
301+
ret = branch_copy_into!(new_branch, tree, copy(tree.l))
302302
return new_branch == tree && ret === new_branch
303303
else
304-
ret = branch_copy!(new_branch, tree, copy(tree.l), copy(tree.r))
304+
ret = branch_copy_into!(new_branch, tree, copy(tree.l), copy(tree.r))
305305
return new_branch == tree && ret === new_branch
306306
end
307307
end
@@ -373,12 +373,12 @@ ni_components = (
373373
),
374374
optional = (
375375
leaf_copy = "copies a leaf node" => _check_leaf_copy,
376-
leaf_copy! = "copies a leaf node in-place" => _check_leaf_copy!,
376+
leaf_copy_into! = "copies a leaf node in-place" => _check_leaf_copy_into!,
377377
leaf_convert = "converts a leaf node" => _check_leaf_convert,
378378
leaf_hash = "computes the hash of a leaf node" => _check_leaf_hash,
379379
leaf_equal = "checks equality of two leaf nodes" => _check_leaf_equal,
380380
branch_copy = "copies a branch node" => _check_branch_copy,
381-
branch_copy! = "copies a branch node in-place" => _check_branch_copy!,
381+
branch_copy_into! = "copies a branch node in-place" => _check_branch_copy_into!,
382382
branch_convert = "converts a branch node" => _check_branch_convert,
383383
branch_hash = "computes the hash of a branch node" => _check_branch_hash,
384384
branch_equal = "checks equality of two branch nodes" => _check_branch_equal,
@@ -419,7 +419,7 @@ ni_description = (
419419
[Arguments()]
420420
)
421421
@implements(
422-
NodeInterface{all_ni_methods_except((:leaf_copy!, :branch_copy!))},
422+
NodeInterface{all_ni_methods_except((:leaf_copy_into!, :branch_copy_into!))},
423423
GraphNode,
424424
[Arguments()]
425425
)

src/NodePreallocation.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
module NodePreallocationModule
2+
3+
using ..NodeModule:
4+
AbstractExpressionNode,
5+
with_type_parameters,
6+
tree_mapreduce,
7+
leaf_copy,
8+
branch_copy,
9+
set_node!
10+
11+
"""
12+
copy_into!(dest::AbstractArray{N}, src::N) where {BS,N<:AbstractExpressionNode}
13+
14+
Copy a node, recursively copying all children nodes, in-place to an
15+
array of pre-allocated nodes. This should result in no extra allocations.
16+
"""
17+
function copy_into!(
18+
dest::AbstractArray{N}, src::N; ref::Union{Nothing,Base.RefValue{<:Integer}}=nothing
19+
) where {N<:AbstractExpressionNode}
20+
_ref = if ref === nothing
21+
Ref(0)
22+
else
23+
ref.x = 0
24+
ref
25+
end
26+
return tree_mapreduce(
27+
leaf -> leaf_copy_into!(@inbounds(dest[_ref.x += 1]), leaf),
28+
identity,
29+
((p, c::Vararg{Any,M}) where {M}) ->
30+
branch_copy_into!(@inbounds(dest[_ref.x += 1]), p, c...),
31+
src,
32+
N;
33+
break_sharing=Val(BS),
34+
)
35+
end
36+
function leaf_copy_into!(dest::N, src::N) where {N<:AbstractExpressionNode}
37+
set_node!(dest, src)
38+
return dest
39+
end
40+
function branch_copy_into!(
41+
dest::N, src::N, children::Vararg{N,M}
42+
) where {N<:AbstractExpressionNode,M}
43+
dest.degree = M
44+
dest.op = src.op
45+
dest.l = children[1]
46+
if M == 2
47+
dest.r = children[2]
48+
end
49+
return dest
50+
end
51+
52+
"""
53+
allocate_container(prototype::AbstractExpressionNode, n=nothing)
54+
55+
Preallocate an array of `n` empty nodes matching the type of `prototype`.
56+
If `n` is not provided, it will be computed from `length(prototype)`.
57+
58+
A given return value of this will be passed to `copy_into!` as the first argument,
59+
so it should be compatible.
60+
"""
61+
function allocate_container(
62+
prototype::N, n::Union{Nothing,Integer}=nothing
63+
) where {T,N<:AbstractExpressionNode{T}}
64+
num_nodes = @something(n, length(prototype))
65+
return N[with_type_parameters(N, T)() for _ in 1:num_nodes]
66+
end
67+
68+
end

src/ParametricExpression.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,12 @@ import ..NodeModule:
1414
with_type_parameters,
1515
preserve_sharing,
1616
leaf_copy,
17-
leaf_copy!,
1817
leaf_convert,
1918
leaf_hash,
2019
leaf_equal,
21-
branch_copy!,
22-
copy_node!,
23-
preallocate_expression
20+
set_node!,
21+
copy_into!,
22+
allocate_container
2423
import ..NodeUtilsModule:
2524
count_constant_nodes,
2625
index_constant_nodes,
@@ -447,16 +446,16 @@ end
447446
return node_type(; val=ex)
448447
end
449448
end
450-
function preallocate_expression(
449+
function allocate_container(
451450
prototype::ParametricExpression, n::Union{Nothing,Integer}=nothing
452451
)
453452
return (;
454-
tree=preallocate_expression(get_contents(prototype), n),
453+
tree=allocate_container(get_contents(prototype), n),
455454
parameters=similar(get_metadata(prototype).parameters),
456455
)
457456
end
458-
function copy_node!(dest::NamedTuple, src::ParametricExpression)
459-
new_tree = copy_node!(dest.tree, get_contents(src))
457+
function copy_into!(dest::NamedTuple, src::ParametricExpression)
458+
new_tree = copy_into!(dest.tree, get_contents(src))
460459
metadata = get_metadata(src)
461460
new_parameters = dest.parameters
462461
new_parameters .= metadata.parameters

src/base.jl

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -485,81 +485,6 @@ function branch_copy(t::N, children::Vararg{Any,M}) where {T,N<:AbstractExpressi
485485
return constructorof(N)(T; op=t.op, children)
486486
end
487487

488-
# In-place versions
489-
490-
"""
491-
copy_node!(dest::AbstractArray{N}, src::N) where {BS,N<:AbstractExpressionNode}
492-
493-
Copy a node, recursively copying all children nodes, in-place to an
494-
array of pre-allocated nodes. This should result in no extra allocations.
495-
"""
496-
function copy_node!(
497-
dest::AbstractArray{N}, src::N; ref::Union{Nothing,Base.RefValue{<:Integer}}=nothing
498-
) where {N<:AbstractExpressionNode}
499-
_ref = if ref === nothing
500-
Ref(0)
501-
else
502-
ref.x = 0
503-
ref
504-
end
505-
return tree_mapreduce(
506-
leaf -> leaf_copy!(@inbounds(dest[_ref.x += 1]), leaf),
507-
identity,
508-
((p, c::Vararg{Any,M}) where {M}) ->
509-
branch_copy!(@inbounds(dest[_ref.x += 1]), p, c...),
510-
src,
511-
N;
512-
break_sharing=Val(BS),
513-
)
514-
end
515-
function leaf_copy!(dest::N, src::N) where {T,N<:AbstractExpressionNode{T}}
516-
dest.degree = 0
517-
if src.constant
518-
dest.constant = true
519-
dest.val = src.val
520-
else
521-
dest.constant = false
522-
dest.feature = src.feature
523-
end
524-
return dest
525-
end
526-
function branch_copy!(
527-
dest::N, src::N, children::Vararg{N,M}
528-
) where {T,N<:AbstractExpressionNode{T},M}
529-
dest.degree = M
530-
dest.op = src.op
531-
dest.l = children[1]
532-
if M == 2
533-
dest.r = children[2]
534-
end
535-
return dest
536-
end
537-
538-
"""
539-
preallocate_expression(prototype::AbstractExpressionNode, n=nothing)
540-
541-
Preallocate an array of empty nodes matching the type of `prototype`. If `n` is provided, use that length, otherwise use `length(prototype)`.
542-
543-
A given return value of this will be passed to `copy_node!` as the first argument,
544-
so it should be compatible.
545-
"""
546-
function preallocate_expression(
547-
prototype::N, n::Union{Nothing,Integer}=nothing
548-
) where {T,N<:AbstractExpressionNode{T}}
549-
num_nodes = @something(n, length(prototype))
550-
return N[with_type_parameters(N, T)() for _ in 1:num_nodes]
551-
end
552-
553-
function copy_node!(::Nothing, src::AbstractExpression)
554-
return copy(src)
555-
end
556-
function preallocate_expression(::AbstractExpression, ::Union{Nothing,Integer}=nothing)
557-
return nothing
558-
end
559-
# We don't require users to overload this, as it's not part of the required interface.
560-
# Also, there's no way to generally do this from the required interface, so for backwards
561-
# compatibility, we just return nothing.
562-
563488
"""
564489
copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
565490

test/test_copy_inplace.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
@testitem "copy_node! - random trees" begin
1+
@testitem "copy_into! - random trees" begin
22
using DynamicExpressions
3-
using DynamicExpressions: copy_node!
3+
using DynamicExpressions: copy_into!
44
include("tree_gen_utils.jl")
55

66
operators = OperatorEnum(; binary_operators=[+, *, /], unary_operators=[sin, cos])
@@ -15,7 +15,7 @@
1515
orig_nodes = dest_array[(n_nodes + 1):end] # Save reference to unused nodes
1616

1717
ref = Ref(0)
18-
result = copy_node!(dest_array, tree; ref)
18+
result = copy_into!(dest_array, tree; ref)
1919

2020
@test ref[] == n_nodes # Increment once per node
2121

@@ -35,17 +35,17 @@
3535
end
3636
end
3737

38-
@testitem "copy_node! - leaf nodes" begin
38+
@testitem "copy_into! - leaf nodes" begin
3939
using DynamicExpressions
40-
using DynamicExpressions: copy_node!
40+
using DynamicExpressions: copy_into!
4141

4242
leaf_constant = Node{Float64}(; val=1.0)
4343
leaf_feature = Node{Float64}(; feature=1)
4444

4545
for leaf in [leaf_constant, leaf_feature]
4646
dest_array = [Node{Float64}() for _ in 1:1]
4747
ref = Ref(0)
48-
result = copy_node!(dest_array, leaf; ref=ref)
48+
result = copy_into!(dest_array, leaf; ref=ref)
4949
@test ref[] == 1
5050
@test result == leaf
5151
@test result === dest_array[1]

0 commit comments

Comments
 (0)