diff --git a/src/RootedTrees.jl b/src/RootedTrees.jl index a3935baf..6d1d4da6 100644 --- a/src/RootedTrees.jl +++ b/src/RootedTrees.jl @@ -16,7 +16,10 @@ export residual_order_condition, elementary_weight, derivative_weight export count_trees -export partition_forest, partition_skeleton, all_partitions, PartitionIterator +export partition_forest, partition_forest!, + partition_skeleton, + all_partitions, PartitionIterator, + create_cache @@ -373,17 +376,44 @@ Form the partition forest of the rooted tree `t` where edges marked with `false` in the `edge_set` are removed. The ith value in the Boolean iterable `edge_set` corresponds to the edge connecting node `i+1` in the level sequence to its parent. -See Section 2.3 of +See also [`partition_skeleton`](@ref) and [`partition_forest!`](@ref). + +# References + +Section 2.3 of - Philippe Chartier, Ernst Hairer, Gilles Vilmart (2010) Algebraic Structures of B-series Foundations of Computational Mathematics [DOI: 10.1007/s10208-010-9065-1](https://doi.org/10.1007/s10208-010-9065-1) """ -function partition_forest(t::RootedTree, _edge_set) +@inline function partition_forest(t::RootedTree, edge_set) + partition_forest!(nothing, t, edge_set) +end + +""" + partition_forest!(cache_forests, t::RootedTree, edge_set) + +Like [`partition_forest`](@ref) but with possible speed-up from memoization +(at the cost of more memory usage in the `cache_forests`). + +`cache_forests` should be created by `create_cache(partition_forest!, t)`, +see [`create_cache`](@ref). It will be updated to store all partition forests +of trees used in the computation. 
+""" +function partition_forest!(cache_forests, t::RootedTree, _edge_set) @boundscheck begin @assert length(t.level_sequence) == length(_edge_set) + 1 end + # `cache_forests = nothing` is used internally to avoid duplicating logic + # for the implementation of `partition_forest` + if cache_forests !== nothing + index_t_edge_set = hash(t, hash(_edge_set)) + if haskey(cache_forests, index_t_edge_set) + return cache_forests[index_t_edge_set] + end + end + edge_set = copy(_edge_set) ls = copy(t.level_sequence) T = eltype(ls) @@ -415,7 +445,7 @@ function partition_forest(t::RootedTree, _edge_set) subtree_edge_set = @view edge_set[subtree_root_index:subtree_last_index-1] # Form the partition forest recursively - append!(forest, partition_forest(subtree, subtree_edge_set)) + append!(forest, partition_forest!(cache_forests, subtree, subtree_edge_set)) # Remove the subtree from the base tree deleteat!(ls, subtree_root_index:subtree_last_index) @@ -427,9 +457,27 @@ function partition_forest(t::RootedTree, _edge_set) # Decide whether canonical representations should be used. Disabling # them will increase the performance. push!(forest, rootedtree!(ls)) + + if cache_forests !== nothing + # TODO: partitions; shall we use a `copy` or even a `deepcopy` of the `forest`? + cache_forests[index_t_edge_set] = forest + end + return forest end +""" + create_cache(::typeof(partition_forest!), t::RootedTree) + +Creates the cache used for memoization in [`partition_forest!`](@ref). +""" +function create_cache(::typeof(partition_forest!), t::RootedTree) + # TODO: partitions; shall we use `copy(t)` here? 
+ T = typeof(copy(t)) + forests = Dict{UInt, Vector{T}}() + return forests +end + # TODO: partitions; add documentation in the README to make them public API """ @@ -439,7 +487,11 @@ Form the partition skeleton of the rooted tree `t`, i.e., the rooted tree obtain by contracting each tree of the partition forest to a single vertex and re-establishing the edges removed to obtain the partition forest. -See `partition_forest` and Section 2.3 of +See also [`partition_forest`](@ref). + +# References + +Section 2.3 of - Philippe Chartier, Ernst Hairer, Gilles Vilmart (2010) Algebraic Structures of B-series Foundations of Computational Mathematics @@ -481,6 +533,79 @@ function partition_skeleton(t::RootedTree, _edge_set) end +""" + partition_skeleton!(cache_skeletons, t::RootedTree, edge_set) + +Like [`partition_skeleton`](@ref) but with possible speed-up from memoization +(at the cost of more memory usage in the `cache_skeletons`). + +`cache_skeletons` should be created by `create_cache(partition_skeleton!, t)`, +see [`create_cache`](@ref). It will be updated to store all partition skeletons. 
+""" +function partition_skeleton!(cache_skeletons, t::RootedTree, _edge_set) + @boundscheck begin + @assert length(t.level_sequence) == length(_edge_set) + 1 + end + + # `cache_skeletons = nothing` is used internally to avoid duplicating logic + # for the implementation of `partition_skeleton` + if cache_skeletons !== nothing + index_t_edge_set = hash(t, hash(_edge_set)) + if haskey(cache_skeletons, index_t_edge_set) + return cache_skeletons[index_t_edge_set] + end + end + + edge_set = copy(_edge_set) + ls = copy(t.level_sequence) + + while any(edge_set) + # Find next edge to contract + subtree_root_index = findfirst(==(true), edge_set) + 1 + + # Contract the corresponding edge by removing the subtree root and promoting + # the rest of the subtree + subtree_last_index = subtree_root_index + 1 + while subtree_last_index <= length(ls) + if ls[subtree_last_index] > ls[subtree_root_index] + ls[subtree_last_index] -= 1 + subtree_last_index += 1 + else + break + end + end + # Remove the root node + deleteat!(ls, subtree_root_index) + deleteat!(edge_set, subtree_root_index-1) + end + + # The level sequence `ls` will not automatically be a canonical representation. + # TODO: partitions; + # Decide whether canonical representations should be used. Disabling + # them will increase the performance. + skeleton = rootedtree!(ls) + + if cache_skeletons !== nothing + # TODO: partitions; shall we use a `copy` or even a `deepcopy` of the `skeleton`? + cache_skeletons[index_t_edge_set] = skeleton + end + + return skeleton +end + +""" + create_cache(::typeof(partition_skeleton!), t::RootedTree) + +Creates the cache used for memoization in [`partition_skeleton!`](@ref). +""" +function create_cache(::typeof(partition_skeleton!), t::RootedTree) + # TODO: partitions; shall we use `copy(t)` here? 
+ T = typeof(copy(t)) + skeletons = Dict{UInt, T}() + return skeletons +end + + # TODO: partitions; add documentation in the README to make them public API """ all_partitions(t::RootedTree) @@ -515,35 +640,56 @@ end Iterator over all partition forests and skeletons of the rooted tree `t`. -See `partition_forest`, `partition_skeleton`, and Section 2.3 of +See [`partition_forest`](@ref), [`partition_skeleton`](@ref), and Section 2.3 of - Philippe Chartier, Ernst Hairer, Gilles Vilmart (2010) Algebraic Structures of B-series Foundations of Computational Mathematics [DOI: 10.1007/s10208-010-9065-1](https://doi.org/10.1007/s10208-010-9065-1) """ -struct PartitionIterator{T<:RootedTree} +struct PartitionIterator{T<:RootedTree, Cache} t::T edge_set::Vector{Bool} + cache::Cache - function PartitionIterator(t::T) where {T<:RootedTree} + function PartitionIterator(t::T, cache::Cache) where {T<:RootedTree, Cache} edge_set = zeros(Bool, order(t) - 1) - new{T}(t, edge_set) + new{T, Cache}(t, edge_set, cache) end end +# TODO: partitions; document cache +# By default, do not use memoization in the PartitionIterator +PartitionIterator(t::RootedTree) = PartitionIterator(t, (nothing, nothing)) + +# By default, use memoization in the PartitionIterator +# function PartitionIterator(t::RootedTree) +# cache_forests = create_cache(partition_forest!, t) +# PartitionIterator(t, cache_forests) +# end + +# By default, use memoization in the PartitionIterator with a global cache +# const CACHE_FORESTS = Dict{UInt, Vector{RootedTree{Int, Vector{Int}}}}() +# PartitionIterator(t::RootedTree) = PartitionIterator(t, CACHE_FORESTS) + +""" + create_cache(::typeof(PartitionIterator), t::RootedTree) + +Creates the cache used for memoization in [`PartitionIterator`](@ref). 
+""" +function create_cache(::typeof(PartitionIterator), t::RootedTree) + cache_forests = create_cache(partition_forest!, t) + cache_skeletons = create_cache(partition_skeleton!, t) + # cache_skeletons = nothing + return (cache_forests, cache_skeletons) +end + Base.IteratorSize(::Type{<:PartitionIterator}) = Base.HasLength() Base.length(partitions::PartitionIterator) = 2^length(partitions.edge_set) Base.eltype(::Type{PartitionIterator{T}}) where {T} = Tuple{Vector{T}, T} function Base.iterate(partitions::PartitionIterator) edge_set_value = 0 - t = partitions.t - edge_set = partitions.edge_set - - digits!(edge_set, edge_set_value, base=2) - forest = partition_forest(t, edge_set) - skeleton = partition_skeleton(t, edge_set) - ((forest, skeleton), edge_set_value + 1) + iterate(partitions, edge_set_value) end function Base.iterate(partitions::PartitionIterator, edge_set_value) @@ -551,10 +697,11 @@ function Base.iterate(partitions::PartitionIterator, edge_set_value) t = partitions.t edge_set = partitions.edge_set + cache_forests, cache_skeletons = partitions.cache digits!(edge_set, edge_set_value, base=2) - forest = partition_forest(t, edge_set) - skeleton = partition_skeleton(t, edge_set) + forest = partition_forest!(cache_forests, t, edge_set) + skeleton = partition_skeleton!(cache_skeletons, t, edge_set) ((forest, skeleton), edge_set_value + 1) end diff --git a/test/runtests.jl b/test/runtests.jl index 0d5a83ca..70d10ddc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -284,6 +284,9 @@ end @test sort!(partition_forest(t, edge_set)) == sort!(reference_forest) reference_skeleton = rootedtree([1, 2, 2]) @test reference_skeleton == partition_skeleton(t, edge_set) + + _forests = create_cache(partition_forest!, t) + @test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set) end let t = rootedtree([1, 2, 3, 4, 3]) @@ -294,6 +297,9 @@ end @test sort!(partition_forest(t, edge_set)) == sort!(reference_forest) reference_skeleton = rootedtree([1, 2, 
3]) @test reference_skeleton == partition_skeleton(t, edge_set) + + _forests = create_cache(partition_forest!, t) + @test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set) end let t = rootedtree([1, 2, 3, 4, 3]) @@ -305,6 +311,9 @@ end @test sort!(partition_forest(t, edge_set)) == sort!(reference_forest) reference_skeleton = rootedtree([1, 2, 3, 3]) @test reference_skeleton == partition_skeleton(t, edge_set) + + _forests = create_cache(partition_forest!, t) + @test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set) end let t = rootedtree([1, 2, 2, 2, 2]) @@ -315,6 +324,9 @@ end @test sort!(partition_forest(t, edge_set)) == sort!(reference_forest) reference_skeleton = rootedtree([1, 2, 2]) @test reference_skeleton == partition_skeleton(t, edge_set) + + _forests = create_cache(partition_forest!, t) + @test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set) end let t = rootedtree([1, 2, 3, 2, 2]) @@ -326,6 +338,9 @@ end @test sort!(partition_forest(t, edge_set)) == sort!(reference_forest) reference_skeleton = rootedtree([1, 2, 3, 2]) @test reference_skeleton == partition_skeleton(t, edge_set) + + _forests = create_cache(partition_forest!, t) + @test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set) end let t = rootedtree([1, 2, 3, 2, 2]) @@ -334,6 +349,9 @@ end @test sort!(partition_forest(t, edge_set)) == sort!(reference_forest) reference_skeleton = rootedtree([1]) @test reference_skeleton == partition_skeleton(t, edge_set) + + _forests = create_cache(partition_forest!, t) + @test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set) end let t = rootedtree([1, 2, 3, 2, 3]) @@ -344,6 +362,9 @@ end @test sort!(partition_forest(t, edge_set)) == sort!(reference_forest) reference_skeleton = rootedtree([1, 2, 3]) @test reference_skeleton == partition_skeleton(t, edge_set) + + _forests = create_cache(partition_forest!, t) + @test partition_forest(t, edge_set) == 
partition_forest!(_forests, t, edge_set) end let t = rootedtree([1, 2, 3, 2, 3]) @@ -355,6 +376,9 @@ end @test sort!(partition_forest(t, edge_set)) == sort!(reference_forest) reference_skeleton = rootedtree([1, 2, 2, 3]) @test reference_skeleton == partition_skeleton(t, edge_set) + + _forests = create_cache(partition_forest!, t) + @test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set) end let t = rootedtree([1, 2, 3, 3, 3]) @@ -365,6 +389,9 @@ end @test sort!(partition_forest(t, edge_set)) == sort!(reference_forest) reference_skeleton = rootedtree([1, 2, 3]) @test reference_skeleton == partition_skeleton(t, edge_set) + + _forests = create_cache(partition_forest!, t) + @test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set) end # additional tests not included in the examples of the paper @@ -375,6 +402,21 @@ end @test sort!(partition_forest(t, edge_set)) == sort!(reference_forest) reference_skeleton = rootedtree([1, 2]) @test reference_skeleton == partition_skeleton(t, edge_set) + + _forests = create_cache(partition_forest!, t) + @test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set) + end + + @testset "memoization" begin + for o in 1:8 + t = rootedtree(collect(1:o)) + edge_set = rand(Bool, order(t) - 1) + _forests = create_cache(partition_forest!, t) + for t in RootedTreeIterator(o) + edge_set = rand(Bool, order(t) - 1) + @test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set) + end + end end end