Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DynamicExpressions"
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
authors = ["MilesCranmer <[email protected]>"]
version = "0.19.3"
version = "1.0.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
22 changes: 11 additions & 11 deletions src/Expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,23 +274,23 @@ copy_node(ex::AbstractExpression; kws...) = copy(ex)
count_nodes(ex::AbstractExpression; kws...) = count_nodes(get_tree(ex); kws...)

function tree_mapreduce(
f::Function,
op::Function,
f::F,
op::G,
ex::AbstractExpression,
result_type::Type=Undefined;
result_type::Type{RT}=Undefined;
kws...,
)
return tree_mapreduce(f, op, get_tree(ex), result_type; kws...)
) where {F<:Function,G<:Function,RT}
return tree_mapreduce(f, op, get_tree(ex), RT; kws...)
end
function tree_mapreduce(
f_leaf::Function,
f_branch::Function,
op::Function,
f_leaf::F,
f_branch::G,
op::H,
ex::AbstractExpression,
result_type::Type=Undefined;
result_type::Type{RT}=Undefined;
kws...,
)
return tree_mapreduce(f_leaf, f_branch, op, get_tree(ex), result_type; kws...)
) where {F<:Function,G<:Function,H<:Function,RT}
return tree_mapreduce(f_leaf, f_branch, op, get_tree(ex), RT; kws...)
end

count_constant_nodes(ex::AbstractExpression) = count_constant_nodes(get_tree(ex))
Expand Down
17 changes: 7 additions & 10 deletions src/Node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,13 @@ include("base.jl")
end
return node_factory(N, T1, val, feature, op, l, r, allocator)
end
function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {N<:AbstractExpressionNode}
return nothing
end
function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {T,N<:AbstractExpressionNode{T}}
if val === nothing && feature === nothing && op === nothing && l === nothing && r === nothing && children === nothing
error(
"Encountered the call for $N() inside the generic constructor. "
* "Did you forget to define `$(Base.typename(N).wrapper){T}() where {T} = new{T}()`?"
)
end
validate_not_all_defaults(::Type{<:AbstractExpressionNode}, val, feature, op, l, r, children) = nothing
validate_not_all_defaults(::Type{<:AbstractExpressionNode{T}}, val, feature, op, l, r, children) where {T} = nothing
function validate_not_all_defaults(::Type{N}, ::Nothing, ::Nothing, ::Nothing, ::Nothing, ::Nothing, ::Nothing) where {T,N<:AbstractExpressionNode{T}}
error(
"Encountered the call for $N() inside the generic constructor. "
* "Did you forget to define `$(Base.typename(N).wrapper){T}() where {T} = new{T}()`?"
)
return nothing
end
"""Create a constant leaf."""
Expand Down
6 changes: 3 additions & 3 deletions src/StructuredExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ kws = (;
f = parse_expression(:(x * x - cos(2.5f0 * y + -0.5f0)); kws...)
g = parse_expression(:(exp(-(y * y))); kws...)

f_plus_g = StructuredExpression((; f, g), nt -> nt.f + nt.g)
f_plus_g = StructuredExpression((; f, g); structure=nt -> nt.f + nt.g)
```

Now, when evaluating `f_plus_g`, this expression type will
Expand Down Expand Up @@ -83,8 +83,8 @@ struct StructuredExpression{
end

function StructuredExpression(
trees::NamedTuple,
structure::F;
trees::NamedTuple;
structure::F,
operators::Union{AbstractOperatorEnum,Nothing}=nothing,
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
extra...,
Expand Down
97 changes: 53 additions & 44 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ function tree_mapreduce(
tree::AbstractNode,
result_type::Type{RT}=Undefined;
f_on_shared::H=(result, is_shared) -> result,
break_sharing=Val(false),
) where {RT,F<:Function,G<:Function,H<:Function}
return tree_mapreduce(f, f, op, tree, RT; f_on_shared, break_sharing)
break_sharing::Val{BS}=Val(false),
) where {RT,F<:Function,G<:Function,H<:Function,BS}
return tree_mapreduce(f, f, op, tree, RT; f_on_shared, break_sharing=Val(BS))
end
function tree_mapreduce(
f_leaf::F1,
Expand All @@ -92,8 +92,8 @@ function tree_mapreduce(
tree::AbstractNode,
result_type::Type{RT}=Undefined;
f_on_shared::H=(result, is_shared) -> result,
break_sharing::Val=Val(false),
) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT}
break_sharing::Val{BS}=Val(false),
) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT,BS}

# Trick taken from here:
# https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
Expand All @@ -108,7 +108,7 @@ function tree_mapreduce(
end
end

sharing = preserve_sharing(typeof(tree)) && break_sharing === Val(false)
sharing = preserve_sharing(typeof(tree)) && !BS

RT == Undefined &&
sharing &&
Expand Down Expand Up @@ -222,14 +222,14 @@ end

Count the number of nodes in the tree.
"""
function count_nodes(tree::AbstractNode; break_sharing=Val(false))
function count_nodes(tree::AbstractNode; break_sharing::Val{BS}=Val(false)) where {BS}
return tree_mapreduce(
_ -> 1,
+,
tree,
Int64;
f_on_shared=(c, is_shared) -> is_shared ? 0 : c,
break_sharing,
break_sharing=Val(BS),
)
end

Expand All @@ -239,10 +239,14 @@ end
Apply a function to each node in a tree without returning the results.
"""
function foreach(
f::F, tree::AbstractNode; break_sharing::Val=Val(false)
) where {F<:Function}
f::F, tree::AbstractNode; break_sharing::Val{BS}=Val(false)
) where {F<:Function,BS}
tree_mapreduce(
t -> (@inline(f(t)); nothing), Returns(nothing), tree, Nothing; break_sharing
t -> (@inline(f(t)); nothing),
Returns(nothing),
tree,
Nothing;
break_sharing=Val(BS),
)
return nothing
end
Expand All @@ -260,10 +264,10 @@ function filter_map(
map_fnc::G,
tree::AbstractNode,
result_type::Type{GT};
break_sharing::Val=Val(false),
) where {F<:Function,G<:Function,GT}
stack = Array{GT}(undef, count(filter_fnc, tree; init=0, break_sharing))
filter_map!(filter_fnc, map_fnc, stack, tree; break_sharing)
break_sharing::Val{BS}=Val(false),
) where {F<:Function,G<:Function,GT,BS}
stack = Array{GT}(undef, count(filter_fnc, tree; init=0, break_sharing=Val(BS)))
filter_map!(filter_fnc, map_fnc, stack, tree; break_sharing=Val(BS))
return stack::Vector{GT}
end

Expand All @@ -277,10 +281,10 @@ function filter_map!(
map_fnc::G,
destination::Vector{GT},
tree::AbstractNode;
break_sharing::Val=Val(false),
) where {GT,F<:Function,G<:Function}
break_sharing::Val{BS}=Val(false),
) where {GT,F<:Function,G<:Function,BS}
pointer = Ref(0)
foreach(tree; break_sharing) do t
foreach(tree; break_sharing=Val(BS)) do t
if @inline(filter_fnc(t))
map_result = @inline(map_fnc(t))::GT
@inbounds destination[pointer.x += 1] = map_result
Expand All @@ -294,55 +298,60 @@ end

Filter nodes of a tree, returning a flat array of the nodes for which the function returns `true`.
"""
function filter(f::F, tree::AbstractNode; break_sharing::Val=Val(false)) where {F<:Function}
return filter_map(f, identity, tree, typeof(tree); break_sharing)
function filter(
f::F, tree::AbstractNode; break_sharing::Val{BS}=Val(false)
) where {F<:Function,BS}
return filter_map(f, identity, tree, typeof(tree); break_sharing=Val(BS))
end

"""
collect(tree::AbstractNode; break_sharing::Val=Val(false))

Collect all nodes in a tree into a flat array in depth-first order.
"""
function collect(tree::AbstractNode; break_sharing::Val=Val(false))
return filter(Returns(true), tree; break_sharing)
function collect(tree::AbstractNode; break_sharing::Val{BS}=Val(false)) where {BS}
return filter(Returns(true), tree; break_sharing=Val(BS))
end

"""
map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false)) where {F<:Function,RT}
map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val{BS}=Val(false)) where {F<:Function,RT,BS}

Map a function over a tree and return a flat array of the results in depth-first order.
Pre-specifying the `result_type` of the function can be used to avoid extra allocations.
"""
function map(
f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false)
) where {F<:Function,RT}
f::F,
tree::AbstractNode,
result_type::Type{RT}=Nothing;
break_sharing::Val{BS}=Val(false),
) where {F<:Function,RT,BS}
if RT == Nothing
return map(f, collect(tree; break_sharing))
return map(f, collect(tree; break_sharing=Val(BS)))
else
return filter_map(Returns(true), f, tree, result_type; break_sharing)
return filter_map(Returns(true), f, tree, result_type; break_sharing=Val(BS))
end
end

"""
count(f::F, tree::AbstractNode; init=0, break_sharing::Val=Val(false)) where {F<:Function}
count(f::F, tree::AbstractNode; init=0, break_sharing::Val{BS}=Val(false)) where {F<:Function,BS}

Count the number of nodes in a tree for which the function returns `true`.
"""
function count(
f::F, tree::AbstractNode; init=0, break_sharing::Val=Val(false)
) where {F<:Function}
f::F, tree::AbstractNode; init=0, break_sharing::Val{BS}=Val(false)
) where {F<:Function,BS}
return tree_mapreduce(
t -> @inline(f(t)) ? 1 : 0,
+,
tree,
Int64;
f_on_shared=(c, is_shared) -> is_shared ? 0 : c,
break_sharing,
break_sharing=Val(BS),
) + init
end

"""
sum(f::Function, tree::AbstractNode; result_type=Undefined, f_on_shared=_default_shared_aggregation, break_sharing::Val=Val(false)) where {F<:Function}
sum(f::Function, tree::AbstractNode; result_type=Undefined, f_on_shared=_default_shared_aggregation, break_sharing::Val{BS}=Val(false)) where {F<:Function,BS}

Sum the results of a function over a tree. For graphs with shared nodes
such as [`GraphNode`](@ref), the function `f_on_shared` is called on the result
Expand Down Expand Up @@ -386,7 +395,7 @@ function mapreduce(
"Must specify `result_type` as a keyword argument to `mapreduce` if `preserve_sharing` is true."
)
end
return tree_mapreduce(f, op, tree, RT; f_on_shared, break_sharing)
return tree_mapreduce(f, op, tree, RT; f_on_shared, break_sharing=Val(BS))
end

isempty(::AbstractNode) = false
Expand All @@ -396,8 +405,8 @@ end
@unstable iterate(::AbstractNode, stack) =
isempty(stack) ? nothing : (popfirst!(stack), stack)
in(item, tree::AbstractNode) = any(t -> t == item, tree)
function length(tree::AbstractNode; break_sharing::Val=Val(false))
return count_nodes(tree; break_sharing)
function length(tree::AbstractNode; break_sharing::Val{BS}=Val(false)) where {BS}
return count_nodes(tree; break_sharing=Val(BS))
end

"""
Expand All @@ -407,8 +416,8 @@ Compute a hash of a tree. This will compute a hash differently
if nodes are shared in a tree. This is ignored if `break_sharing` is set to `Val(true)`.
"""
function hash(
tree::AbstractExpressionNode{T}, h::UInt=zero(UInt); break_sharing::Val=Val(false)
) where {T}
tree::AbstractExpressionNode{T}, h::UInt=zero(UInt); break_sharing::Val{BS}=Val(false)
) where {T,BS}
return tree_mapreduce(
t -> leaf_hash(h, t),
identity,
Expand All @@ -417,7 +426,7 @@ function hash(
UInt;
f_on_shared=(cur_hash, is_shared) ->
is_shared ? hash((:shared, cur_hash), h) : cur_hash,
break_sharing,
break_sharing=Val(BS),
)
end
function leaf_hash(h::UInt, t::AbstractExpressionNode)
Expand All @@ -428,17 +437,17 @@ function branch_hash(h::UInt, t::AbstractExpressionNode, children::Vararg{Any,M}
end

"""
copy_node(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
copy_node(tree::AbstractExpressionNode; break_sharing::Val{BS}=Val(false)) where {BS}

Copy a node, recursively copying all children nodes.
This is more efficient than the built-in copy.

If `break_sharing` is set to `Val(true)`, sharing in a tree will be ignored.
"""
function copy_node(
tree::N; break_sharing::Val=Val(false)
) where {T,N<:AbstractExpressionNode{T}}
return tree_mapreduce(leaf_copy, identity, branch_copy, tree, N; break_sharing)
tree::N; break_sharing::Val{BS}=Val(false)
) where {T,N<:AbstractExpressionNode{T},BS}
return tree_mapreduce(leaf_copy, identity, branch_copy, tree, N; break_sharing=Val(BS))
end
function leaf_copy(t::N) where {T,N<:AbstractExpressionNode{T}}
if t.constant
Expand All @@ -459,8 +468,8 @@ This is more efficient than the built-in copy.

If `break_sharing` is set to `Val(true)`, sharing in a tree will be ignored.
"""
function copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
return copy_node(tree; break_sharing)
function copy(tree::AbstractExpressionNode; break_sharing::Val{BS}=Val(false)) where {BS}
return copy_node(tree; break_sharing=Val(BS))
end

"""
Expand Down
16 changes: 9 additions & 7 deletions test/test_structured_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

shower(ex) = sprint((io, e) -> show(io, MIME"text/plain"(), e), ex)

f_plus_g = StructuredExpression((; f, g), nt -> nt.f + nt.g)
f_div_g = StructuredExpression((; f, g), nt -> nt.f / nt.g)
cos_f = StructuredExpression((; f), nt -> cos(nt.f))
exp_g = StructuredExpression((; g), nt -> exp(nt.g))
f_plus_g = StructuredExpression((; f, g); structure=nt -> nt.f + nt.g)
f_div_g = StructuredExpression((; f, g); structure=nt -> nt.f / nt.g)
cos_f = StructuredExpression((; f); structure=nt -> cos(nt.f))
exp_g = StructuredExpression((; g); structure=nt -> exp(nt.g))

@test shower(f_plus_g) == "((x * x) - cos((2.5 * y) + -0.5)) + exp(-(y * y))"
@test shower(f_div_g) == "((x * x) - cos((2.5 * y) + -0.5)) / exp(-(y * y))"
Expand Down Expand Up @@ -43,7 +43,7 @@ end
f = parse_expression(:(x * x - cos(2.5f0 * y + -0.5f0)); kws...)
g = parse_expression(:(exp(-(y * y))); kws...)

ex = StructuredExpression((; f, g), nt -> nt.f + nt.g)
ex = StructuredExpression((; f, g); structure=nt -> nt.f + nt.g)

@test test(ExpressionInterface, StructuredExpression, [ex])
end
Expand All @@ -64,7 +64,7 @@ end
g = parse_expression(:(exp(-(y * y))); kws...)

c = [1]
ex = StructuredExpression((; f, g), my_factory; a=c)
ex = StructuredExpression((; f, g); structure=my_factory, a=c)

@test ex.metadata.extra.a[] == 1
@test ex.metadata.extra.a === c
Expand Down Expand Up @@ -114,7 +114,9 @@ end
This is a composite `AbstractExpression` object that composes multiple
expressions during evaluation.
=#
ex = StructuredExpression((; f, g), nt -> nt.f + nt.g; operators, variable_names)
ex = StructuredExpression(
(; f, g); structure=nt -> nt.f + nt.g, operators, variable_names
)
ex
@test typeof(ex) <: AbstractExpression{Float64,<:Node{Float64}} #src
#=
Expand Down
Loading