Skip to content
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
version = "2.4.1"

[deps]
BorrowChecker = "7bdcaa52-c310-4bb0-bf54-d941056ed284"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
Expand All @@ -29,6 +30,7 @@ DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
DynamicExpressionsZygoteExt = "Zygote"

[compat]
BorrowChecker = "0.4"
Bumper = "0.6"
ChainRulesCore = "1"
Compat = "4.16"
Expand Down
3 changes: 3 additions & 0 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ module DynamicExpressions

using DispatchDoctor: @stable, @unstable

using BorrowChecker: BorrowChecker
BorrowChecker.PreferencesModule.disable_by_default!(@__MODULE__)

@stable default_mode = "disable" begin
include("Utils.jl")
include("ValueInterface.jl")
Expand Down
34 changes: 23 additions & 11 deletions src/NodePreallocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ using ..NodeModule:
set_node!,
set_children!

using BorrowChecker: BorrowChecker
using BorrowChecker: @unsafe

"""
allocate_container(prototype::AbstractExpressionNode, n=nothing)

Expand All @@ -18,7 +21,7 @@ If `n` is not provided, it will be computed from `length(prototype)`.
A given return value of this will be passed to `copy_into!` as the first argument,
so it should be compatible.
"""
function allocate_container(
BorrowChecker.@safe function allocate_container(
prototype::N, n::Union{Nothing,Integer}=nothing
) where {T,N<:AbstractExpressionNode{T}}
num_nodes = @something(n, length(prototype))
Expand All @@ -33,27 +36,36 @@ This should result in no extra allocations.
"""
function copy_into!(
dest::AbstractArray{N}, src::N; ref::Union{Nothing,Base.RefValue{<:Integer}}=nothing
) where {N<:AbstractExpressionNode}
return copy_into!(dest, src, ref)
end

BorrowChecker.@safe function copy_into!(
dest::AbstractArray{N}, src::N, ref::Union{Nothing,Base.RefValue{<:Integer}}
) where {N<:AbstractExpressionNode}
_ref = if ref === nothing
Ref(0)
else
ref.x = 0
ref
end
return tree_mapreduce(
leaf -> leaf_copy_into!(@inbounds(dest[_ref.x += 1]), leaf),
identity,
((p, c::Vararg{Any,M}) where {M}) ->
branch_copy_into!(@inbounds(dest[_ref.x += 1]), p, c...),
src,
N,
)
return @unsafe begin
tree_mapreduce(
leaf -> leaf_copy_into!(@inbounds(dest[_ref.x += 1]), leaf),
identity,
(p, c...) -> branch_copy_into!(@inbounds(dest[_ref.x += 1]), p, c...),
src,
N,
)
end
end
function leaf_copy_into!(dest::N, src::N) where {N<:AbstractExpressionNode}
BorrowChecker.@safe function leaf_copy_into!(
dest::N, src::N
) where {N<:AbstractExpressionNode}
set_node!(dest, src)
return dest
end
function branch_copy_into!(
BorrowChecker.@safe function branch_copy_into!(
dest::N, src::N, children::Vararg{Any,M}
) where {T,D,N<:AbstractExpressionNode{T,D},M}
dest.degree = M
Expand Down
19 changes: 16 additions & 3 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ import Base:
reduce,
sum

using BorrowChecker: BorrowChecker
using BorrowChecker: @unsafe

using DispatchDoctor: @unstable
using ..UtilsModule: Undefined

Expand Down Expand Up @@ -496,16 +499,26 @@ If `break_sharing` is set to `Val(true)`, sharing in a tree will be ignored.
function copy_node(
tree::N; break_sharing::Val{BS}=Val(false)
) where {T,N<:AbstractExpressionNode{T},BS}
return tree_mapreduce(leaf_copy, identity, branch_copy, tree, N; break_sharing=Val(BS))
return copy_node(tree, break_sharing)
end

BorrowChecker.@safe function copy_node(
tree::N, break_sharing::Val{BS}
) where {T,N<:AbstractExpressionNode{T},BS}
return @unsafe begin
tree_mapreduce(leaf_copy, identity, branch_copy, tree, N; break_sharing)
end
end
function leaf_copy(t::N) where {T,N<:AbstractExpressionNode{T}}
BorrowChecker.@safe function leaf_copy(t::N) where {T,N<:AbstractExpressionNode{T}}
if t.constant
return constructorof(N)(; val=t.val)
else
return constructorof(N)(T; feature=t.feature)
end
end
function branch_copy(t::N, children::Vararg{Any,M}) where {T,N<:AbstractExpressionNode{T},M}
BorrowChecker.@safe function branch_copy(
t::N, children::Vararg{Any,M}
) where {T,N<:AbstractExpressionNode{T},M}
return constructorof(N)(T; op=t.op, children)
end

Expand Down