Skip to content

Commit b5b40a7

Browse files
committed
fix: various issues in preallocation interface
1 parent 95f2bdb commit b5b40a7

File tree

4 files changed

+21
-12
lines changed

4 files changed

+21
-12
lines changed

src/Interfaces.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,10 @@ using ..NodeModule:
1313
default_allocator,
1414
with_type_parameters,
1515
leaf_copy,
16-
leaf_copy_into!,
1716
leaf_convert,
1817
leaf_hash,
1918
leaf_equal,
2019
branch_copy,
21-
branch_copy_into!,
2220
branch_convert,
2321
branch_hash,
2422
branch_equal,
@@ -38,6 +36,8 @@ using ..NodeUtilsModule:
3836
has_constants,
3937
get_scalar_constants,
4038
set_scalar_constants!
39+
using ..NodePreallocationModule:
40+
copy_into!, leaf_copy_into!, branch_copy_into!, allocate_container
4141
using ..StringsModule: string_tree
4242
using ..EvaluateModule: eval_tree_array
4343
using ..EvaluateDerivativeModule: eval_grad_tree_array
@@ -96,6 +96,11 @@ function _check_with_metadata(ex::AbstractExpression)
9696
end
9797

9898
## optional
99+
function _check_copy_into!(ex::AbstractExpression)
100+
container = allocate_container(ex)
101+
prealloc_ex = copy_into!(container, ex)
102+
return container !== nothing && prealloc_ex == ex && prealloc_ex !== container
103+
end
99104
function _check_count_nodes(ex::AbstractExpression)
100105
return count_nodes(ex) isa Int64
101106
end
@@ -156,6 +161,7 @@ ei_components = (
156161
with_metadata = "returns the expression with different metadata" => _check_with_metadata,
157162
),
158163
optional = (
164+
copy_into! = "copies an expression into a preallocated container" => _check_copy_into!,
159165
count_nodes = "counts the number of nodes in the expression tree" => _check_count_nodes,
160166
count_constant_nodes = "counts the number of constant nodes in the expression tree" => _check_count_constant_nodes,
161167
count_depth = "calculates the depth of the expression tree" => _check_count_depth,
@@ -260,6 +266,11 @@ function _check_tree_mapreduce(tree::AbstractExpressionNode)
260266
end
261267

262268
## optional
269+
function _check_copy_into!(tree::AbstractExpressionNode)
270+
container = allocate_container(tree)
271+
prealloc_tree = copy_into!(container, tree)
272+
return container !== nothing && prealloc_tree == tree && prealloc_tree !== container
273+
end
263274
function _check_leaf_copy(tree::AbstractExpressionNode)
264275
tree.degree != 0 && return true
265276
return leaf_copy(tree) isa typeof(tree)
@@ -372,6 +383,7 @@ ni_components = (
372383
tree_mapreduce = "applies a function across the tree" => _check_tree_mapreduce
373384
),
374385
optional = (
386+
copy_into! = "copies a node into a preallocated container" => _check_copy_into!,
375387
leaf_copy = "copies a leaf node" => _check_leaf_copy,
376388
leaf_copy_into! = "copies a leaf node in-place" => _check_leaf_copy_into!,
377389
leaf_convert = "converts a leaf node" => _check_leaf_convert,

src/NodePreallocation.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using ..NodeModule:
99
set_node!
1010

1111
"""
12-
copy_into!(dest::AbstractArray{N}, src::N) where {BS,N<:AbstractExpressionNode}
12+
copy_into!(dest::AbstractArray{N}, src::N) where {N<:AbstractExpressionNode}
1313
1414
Copy a node, recursively copying all children nodes, in-place to an
1515
array of pre-allocated nodes. This should result in no extra allocations.
@@ -29,8 +29,7 @@ function copy_into!(
2929
((p, c::Vararg{Any,M}) where {M}) ->
3030
branch_copy_into!(@inbounds(dest[_ref.x += 1]), p, c...),
3131
src,
32-
N;
33-
break_sharing=Val(BS),
32+
N,
3433
)
3534
end
3635
function leaf_copy_into!(dest::N, src::N) where {N<:AbstractExpressionNode}

src/ParametricExpression.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@ using DispatchDoctor: @stable, @unstable
44
using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk
55

66
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
7-
using ..NodeModule:
8-
AbstractExpressionNode, Node, tree_mapreduce, with_contents, with_metadata
9-
using ..ExpressionModule: AbstractExpression, Metadata
7+
using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
8+
using ..ExpressionModule: AbstractExpression, Metadata, with_contents, with_metadata
109
using ..ChainRulesModule: NodeTangent
1110

1211
import ..NodeModule:
@@ -17,9 +16,8 @@ import ..NodeModule:
1716
leaf_convert,
1817
leaf_hash,
1918
leaf_equal,
20-
set_node!,
21-
copy_into!,
22-
allocate_container
19+
set_node!
20+
import ..NodePreallocationModule: copy_into!, allocate_container
2321
import ..NodeUtilsModule:
2422
count_constant_nodes,
2523
index_constant_nodes,

test/test_parametric_expression.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ end
2626
using Interfaces: test
2727

2828
ex = @parse_expression(
29-
x + y + p1 * p2,
29+
x + y + p1 * p2 + 1.5,
3030
binary_operators = [+, -, *, /],
3131
variable_names = ["x", "y"],
3232
node_type = ParametricNode,

0 commit comments

Comments
 (0)