Skip to content

Commit eecc9da

Browse files
committed
feat: add preallocation for abstract structured expression
1 parent b5b40a7 commit eecc9da

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

src/Expression.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import ..NodeUtilsModule:
1919
count_scalar_constants,
2020
get_scalar_constants,
2121
set_scalar_constants!
22+
import ..NodePreallocationModule: copy_into!, allocate_container
2223
import ..EvaluateModule: eval_tree_array, differentiable_eval_tree_array
2324
import ..EvaluateDerivativeModule: eval_grad_tree_array
2425
import ..EvaluationHelpersModule: _grad_evaluator

src/StructuredExpression.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ using ..ExpressionModule: AbstractExpression, Metadata, node_type
66
using ..ChainRulesModule: NodeTangent
77

88
import ..NodeModule: constructorof
9+
import ..NodePreallocationModule: copy_into!, allocate_container
910
import ..ExpressionModule:
1011
get_contents,
1112
get_metadata,
1213
get_tree,
1314
get_operators,
1415
get_variable_names,
16+
with_contents,
1517
Metadata,
1618
_copy,
1719
_data,
@@ -164,4 +166,16 @@ function set_scalar_constants!(e::AbstractStructuredExpression, constants, refs)
164166
return e
165167
end
166168

169+
function allocate_container(
170+
e::AbstractStructuredExpression, n::Union{Nothing,Integer}=nothing
171+
)
172+
ts = get_contents(e)
173+
return (; trees=NamedTuple{keys(ts)}(map(t -> allocate_container(t, n), values(ts))))
174+
end
175+
function copy_into!(dest::NamedTuple, src::AbstractStructuredExpression)
176+
ts = get_contents(src)
177+
new_contents = NamedTuple{keys(ts)}(map(copy_into!, values(dest.trees), values(ts)))
178+
return with_contents(src, new_contents)
179+
end
180+
167181
end

0 commit comments

Comments
 (0)