Skip to content

Commit ed086e7

Browse files
committed
feat: add preallocation utilities for expression
1 parent e7955d6 commit ed086e7

File tree

3 files changed

+71
-11
lines changed

3 files changed

+71
-11
lines changed

src/Expression.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,14 @@ function extract_gradient(
317317
return extract_gradient(gradient.tree, get_tree(ex))
318318
end
319319

320+
function preallocate_expression(prototype::Expression, n::Union{Nothing,Integer}=nothing)
321+
return (; tree=preallocate_expression(DE.get_contents(prototype), n))
322+
end
323+
function DE.copy_node!(dest::NamedTuple, src::Expression)
324+
tree = DE.copy_node!(dest.tree, DE.get_contents(src))
325+
return DE.with_contents(src, tree)
326+
end
327+
320328
"""
321329
string_tree(
322330
ex::AbstractExpression,

src/ParametricExpression.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ using DispatchDoctor: @stable, @unstable
44
using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk
55

66
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
7-
using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
7+
using ..NodeModule:
8+
AbstractExpressionNode, Node, tree_mapreduce, with_contents, with_metadata
89
using ..ExpressionModule: AbstractExpression, Metadata
910
using ..ChainRulesModule: NodeTangent
1011

@@ -17,7 +18,9 @@ import ..NodeModule:
1718
leaf_convert,
1819
leaf_hash,
1920
leaf_equal,
20-
branch_copy!
21+
branch_copy!,
22+
copy_node!,
23+
preallocate_expression
2124
import ..NodeUtilsModule:
2225
count_constant_nodes,
2326
index_constant_nodes,
@@ -444,6 +447,28 @@ end
444447
return node_type(; val=ex)
445448
end
446449
end
450+
function preallocate_expression(
451+
prototype::ParametricExpression, n::Union{Nothing,Integer}=nothing
452+
)
453+
return (;
454+
tree=preallocate_expression(get_contents(prototype), n),
455+
parameters=similar(get_metadata(prototype).parameters),
456+
)
457+
end
458+
function copy_node!(dest::NamedTuple, src::ParametricExpression)
459+
new_tree = copy_node!(dest.tree, get_contents(src))
460+
metadata = get_metadata(src)
461+
new_parameters = dest.parameters
462+
new_parameters .= metadata.parameters
463+
new_metadata = Metadata((;
464+
operators=metadata.operators,
465+
variable_names=metadata.variable_names,
466+
parameters=new_parameters,
467+
parameter_names=metadata.parameter_names,
468+
))
469+
# TODO: Better interface for this^
470+
return with_metadata(with_contents(src, new_tree), new_metadata)
471+
end
447472
###############################################################################
448473

449474
end

src/base.jl

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -488,23 +488,25 @@ end
488488
# In-place versions
489489

490490
"""
491-
copy_node!(dest::AbstractArray{N}, src::N; break_sharing::Val{BS}=Val(false)) where {BS,N<:AbstractExpressionNode}
491+
copy_node!(dest::AbstractArray{N}, src::N) where {BS,N<:AbstractExpressionNode}
492492
493493
Copy a node, recursively copying all children nodes, in-place to an
494494
array of pre-allocated nodes. This should result in no extra allocations.
495495
"""
496496
function copy_node!(
497-
dest::AbstractArray{N},
498-
src::N;
499-
break_sharing::Val{BS}=Val(false),
500-
ref::Base.RefValue{<:Integer}=Ref(0),
501-
) where {BS,N<:AbstractExpressionNode}
502-
ref.x = 0
497+
dest::AbstractArray{N}, src::N; ref::Union{Nothing,Base.RefValue{<:Integer}}=nothing
498+
) where {N<:AbstractExpressionNode}
499+
_ref = if ref === nothing
500+
Ref(0)
501+
else
502+
ref.x = 0
503+
ref
504+
end
503505
return tree_mapreduce(
504-
leaf -> leaf_copy!(@inbounds(dest[ref.x += 1]), leaf),
506+
leaf -> leaf_copy!(@inbounds(dest[_ref.x += 1]), leaf),
505507
identity,
506508
((p, c::Vararg{Any,M}) where {M}) ->
507-
branch_copy!(@inbounds(dest[ref.x += 1]), p, c...),
509+
branch_copy!(@inbounds(dest[_ref.x += 1]), p, c...),
508510
src,
509511
N;
510512
break_sharing=Val(BS),
@@ -533,6 +535,31 @@ function branch_copy!(
533535
return dest
534536
end
535537

538+
"""
539+
preallocate_expression(prototype::AbstractExpressionNode, n=nothing)
540+
541+
Preallocate an array of empty nodes matching the type of `prototype`. If `n` is provided, use that length, otherwise use `length(prototype)`.
542+
543+
A given return value of this will be passed to `copy_node!` as the first argument,
544+
so it should be compatible.
545+
"""
546+
function preallocate_expression(
547+
prototype::N, n::Union{Nothing,Integer}=nothing
548+
) where {T,N<:AbstractExpressionNode{T}}
549+
num_nodes = @something(n, length(prototype))
550+
return N[with_type_parameters(N, T)() for _ in 1:num_nodes]
551+
end
552+
553+
function copy_node!(::Nothing, src::AbstractExpression)
554+
return copy(src)
555+
end
556+
function preallocate_expression(::AbstractExpression, ::Union{Nothing,Integer}=nothing)
557+
return nothing
558+
end
559+
# We don't require users to overload this, as it's not part of the required interface.
560+
# Also, there's no way to generally do this from the required interface, so for backwards
561+
# compatibility, we just return nothing.
562+
536563
"""
537564
copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
538565

0 commit comments

Comments
 (0)