Skip to content

WIP: memoization #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 165 additions & 18 deletions src/RootedTrees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ export residual_order_condition, elementary_weight, derivative_weight

export count_trees

export partition_forest, partition_skeleton, all_partitions, PartitionIterator
export partition_forest, partition_forest!,
partition_skeleton,
all_partitions, PartitionIterator,
create_cache



Expand Down Expand Up @@ -373,17 +376,44 @@ Form the partition forest of the rooted tree `t` where edges marked with `false`
in the `edge_set` are removed. The ith value in the Boolean iterable `edge_set`
corresponds to the edge connecting node `i+1` in the level sequence to its parent.

See Section 2.3 of
See also [`partition_skeleton`](@ref) and [`partition_forest!`](@ref).

# References

Section 2.3 of
- Philippe Chartier, Ernst Hairer, Gilles Vilmart (2010)
Algebraic Structures of B-series
Foundations of Computational Mathematics
[DOI: 10.1007/s10208-010-9065-1](https://doi.org/10.1007/s10208-010-9065-1)
"""
function partition_forest(t::RootedTree, _edge_set)
@inline function partition_forest(t::RootedTree, edge_set)
partition_forest!(nothing, t, edge_set)
end

"""
partition_forest!(cache_forests, t::RootedTree, edge_set)

Like [`partition_forest`](@ref) but with possible speed-up from memoization
(at the cost of more memory usage in the `cache_forests`).

`cache_forests` should be a created by `create_cache(partition_forest!, t)`,
see [`create_cache`](@ref). It will be updated to store all partition forests
of trees used in the computation.
"""
function partition_forest!(cache_forests, t::RootedTree, _edge_set)
@boundscheck begin
@assert length(t.level_sequence) == length(_edge_set) + 1
end

# `cache_forests = nothing` is used internally to avoid duplicating logic
# for the implementation of `partition_forest`
if cache_forests !== nothing
index_t_edge_set = hash(t, hash(_edge_set))
if haskey(cache_forests, index_t_edge_set)
return cache_forests[index_t_edge_set]
end
end

edge_set = copy(_edge_set)
ls = copy(t.level_sequence)
T = eltype(ls)
Expand Down Expand Up @@ -415,7 +445,7 @@ function partition_forest(t::RootedTree, _edge_set)
subtree_edge_set = @view edge_set[subtree_root_index:subtree_last_index-1]

# Form the partition forest recursively
append!(forest, partition_forest(subtree, subtree_edge_set))
append!(forest, partition_forest!(cache_forests, subtree, subtree_edge_set))

# Remove the subtree from the base tree
deleteat!(ls, subtree_root_index:subtree_last_index)
Expand All @@ -427,9 +457,27 @@ function partition_forest(t::RootedTree, _edge_set)
# Decide whether canonical representations should be used. Disabling
# them will increase the performance.
push!(forest, rootedtree!(ls))

if cache_forests !== nothing
# TODO: partitions; shall we use a `copy` or even a `deepcopy` of the `forest`?
cache_forests[index_t_edge_set] = forest
end

return forest
end

"""
create_cache(::typeof(partition_forest!), t::RootedTree)

Creates the cache used for memoization in [`partition_forest!`](@ref).
"""
function create_cache(::typeof(partition_forest!), t::RootedTree)
# TODO: partitions; shall we use `copy(t)` here?
T = typeof(copy(t))
forests = Dict{UInt, Vector{T}}()
return forests
end


# TODO: partitions; add documentation in the README to make them public API
"""
Expand All @@ -439,7 +487,11 @@ Form the partition skeleton of the rooted tree `t`, i.e., the rooted tree obtain
by contracting each tree of the partition forest to a single vertex and re-establishing
the edges removed to obtain the partition forest.

See `partition_forest` and Section 2.3 of
See also [`partition_forest`](@ref).

# References

Section 2.3 of
- Philippe Chartier, Ernst Hairer, Gilles Vilmart (2010)
Algebraic Structures of B-series
Foundations of Computational Mathematics
Expand Down Expand Up @@ -481,6 +533,79 @@ function partition_skeleton(t::RootedTree, _edge_set)
end


"""
partition_skeleton!(cache_skeletons, t::RootedTree, edge_set)

Like [`partition_skeleton`](@ref) but with possible speed-up from memoization
(at the cost of more memory usage in the `cache_skeletons`).

`cache_skeletons` should be a created by `create_cache(partition_skeleton!, t)`,
see [`create_cache`](@ref). It will be updated to store all partition skeletons.
"""
function partition_skeleton!(cache_skeletons, t::RootedTree, _edge_set)
@boundscheck begin
@assert length(t.level_sequence) == length(_edge_set) + 1
end

# `cache_skeletons = nothing` is used internally to avoid duplicating logic
# for the implementation of `partition_skeleton`
if cache_skeletons !== nothing
index_t_edge_set = hash(t, hash(_edge_set))
if haskey(cache_skeletons, index_t_edge_set)
return cache_skeletons[index_t_edge_set]
end
end

edge_set = copy(_edge_set)
ls = copy(t.level_sequence)

while any(edge_set)
# Find next edge to contract
subtree_root_index = findfirst(==(true), edge_set) + 1

# Contract the corresponding edge by removing the subtree root and promoting
# the rest of the subtree
subtree_last_index = subtree_root_index + 1
while subtree_last_index <= length(ls)
if ls[subtree_last_index] > ls[subtree_root_index]
ls[subtree_last_index] -= 1
subtree_last_index += 1
else
break
end
end
# Remove the root node
deleteat!(ls, subtree_root_index)
deleteat!(edge_set, subtree_root_index-1)
end

# The level sequence `ls` will not automatically be a canonical representation.
# TODO: partitions;
# Decide whether canonical representations should be used. Disabling
# them will increase the performance.
skeleton = rootedtree!(ls)

if cache_skeletons !== nothing
# TODO: partitions; shall we use a `copy` or even a `deepcopy` of the `skeleton`?
cache_skeletons[index_t_edge_set] = skeleton
end

return skeleton
end

"""
create_cache(::typeof(partition_skeleton!), t::RootedTree)

Creates the cache used for memoization in [`partition_skeleton!`](@ref).
"""
function create_cache(::typeof(partition_skeleton!), t::RootedTree)
# TODO: partitions; shall we use `copy(t)` here?
T = typeof(copy(t))
skeletons = Dict{UInt, T}()
return skeletons
end


# TODO: partitions; add documentation in the README to make them public API
"""
all_partitions(t::RootedTree)
Expand Down Expand Up @@ -515,46 +640,68 @@ end

Iterator over all partition forests and skeletons of the rooted tree `t`.

See `partition_forest`, `partition_skeleton`, and Section 2.3 of
See [`partition_forest`](@ref), [`partition_skeleton`](@ref), and Section 2.3 of
- Philippe Chartier, Ernst Hairer, Gilles Vilmart (2010)
Algebraic Structures of B-series
Foundations of Computational Mathematics
[DOI: 10.1007/s10208-010-9065-1](https://doi.org/10.1007/s10208-010-9065-1)
"""
struct PartitionIterator{T<:RootedTree}
struct PartitionIterator{T<:RootedTree, Cache}
t::T
edge_set::Vector{Bool}
cache::Cache

function PartitionIterator(t::T) where {T<:RootedTree}
function PartitionIterator(t::T, cache::Cache) where {T<:RootedTree, Cache}
edge_set = zeros(Bool, order(t) - 1)
new{T}(t, edge_set)
new{T, Cache}(t, edge_set, cache)
end
end

# TODO: partitions; document cache
# By default, do not use memoization in the PartitionIterator
PartitionIterator(t::RootedTree) = PartitionIterator(t, (nothing, nothing))

# By default, use memoization in the PartitionIterator
# function PartitionIterator(t::RootedTree)
# cache_forests = create_cache(partition_forest!, t)
# PartitionIterator(t, cache_forests)
# end

# By default, use memoization in the PartitionIterator with a global cache
# const CACHE_FORESTS = Dict{UInt, Vector{RootedTree{Int, Vector{Int}}}}()
# PartitionIterator(t::RootedTree) = PartitionIterator(t, CACHE_FORESTS)

"""
create_cache(::typeof(PartitionIterator), t::RootedTree)

Creates the cache used for memoization in [`PartitionIterator`](@ref).
"""
function create_cache(::typeof(PartitionIterator), t::RootedTree)
cache_forests = create_cache(partition_forest!, t)
cache_skeletons = create_cache(partition_skeleton!, t)
# cache_skeletons = nothing
return (cache_forests, cache_skeletons)
end

Base.IteratorSize(::Type{<:PartitionIterator}) = Base.HasLength()
Base.length(partitions::PartitionIterator) = 2^length(partitions.edge_set)
Base.eltype(::Type{PartitionIterator{T}}) where {T} = Tuple{Vector{T}, T}

function Base.iterate(partitions::PartitionIterator)
edge_set_value = 0
t = partitions.t
edge_set = partitions.edge_set

digits!(edge_set, edge_set_value, base=2)
forest = partition_forest(t, edge_set)
skeleton = partition_skeleton(t, edge_set)
((forest, skeleton), edge_set_value + 1)
iterate(partitions, edge_set_value)
end

function Base.iterate(partitions::PartitionIterator, edge_set_value)
edge_set_value >= length(partitions) && return nothing

t = partitions.t
edge_set = partitions.edge_set
cache_forests, cache_skeletons = partitions.cache

digits!(edge_set, edge_set_value, base=2)
forest = partition_forest(t, edge_set)
skeleton = partition_skeleton(t, edge_set)
forest = partition_forest!(cache_forests, t, edge_set)
skeleton = partition_skeleton!(cache_skeletons, t, edge_set)
((forest, skeleton), edge_set_value + 1)
end

Expand Down
42 changes: 42 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ end
@test sort!(partition_forest(t, edge_set)) == sort!(reference_forest)
reference_skeleton = rootedtree([1, 2, 2])
@test reference_skeleton == partition_skeleton(t, edge_set)

_forests = create_cache(partition_forest!, t)
@test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set)
end

let t = rootedtree([1, 2, 3, 4, 3])
Expand All @@ -294,6 +297,9 @@ end
@test sort!(partition_forest(t, edge_set)) == sort!(reference_forest)
reference_skeleton = rootedtree([1, 2, 3])
@test reference_skeleton == partition_skeleton(t, edge_set)

_forests = create_cache(partition_forest!, t)
@test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set)
end

let t = rootedtree([1, 2, 3, 4, 3])
Expand All @@ -305,6 +311,9 @@ end
@test sort!(partition_forest(t, edge_set)) == sort!(reference_forest)
reference_skeleton = rootedtree([1, 2, 3, 3])
@test reference_skeleton == partition_skeleton(t, edge_set)

_forests = create_cache(partition_forest!, t)
@test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set)
end

let t = rootedtree([1, 2, 2, 2, 2])
Expand All @@ -315,6 +324,9 @@ end
@test sort!(partition_forest(t, edge_set)) == sort!(reference_forest)
reference_skeleton = rootedtree([1, 2, 2])
@test reference_skeleton == partition_skeleton(t, edge_set)

_forests = create_cache(partition_forest!, t)
@test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set)
end

let t = rootedtree([1, 2, 3, 2, 2])
Expand All @@ -326,6 +338,9 @@ end
@test sort!(partition_forest(t, edge_set)) == sort!(reference_forest)
reference_skeleton = rootedtree([1, 2, 3, 2])
@test reference_skeleton == partition_skeleton(t, edge_set)

_forests = create_cache(partition_forest!, t)
@test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set)
end

let t = rootedtree([1, 2, 3, 2, 2])
Expand All @@ -334,6 +349,9 @@ end
@test sort!(partition_forest(t, edge_set)) == sort!(reference_forest)
reference_skeleton = rootedtree([1])
@test reference_skeleton == partition_skeleton(t, edge_set)

_forests = create_cache(partition_forest!, t)
@test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set)
end

let t = rootedtree([1, 2, 3, 2, 3])
Expand All @@ -344,6 +362,9 @@ end
@test sort!(partition_forest(t, edge_set)) == sort!(reference_forest)
reference_skeleton = rootedtree([1, 2, 3])
@test reference_skeleton == partition_skeleton(t, edge_set)

_forests = create_cache(partition_forest!, t)
@test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set)
end

let t = rootedtree([1, 2, 3, 2, 3])
Expand All @@ -355,6 +376,9 @@ end
@test sort!(partition_forest(t, edge_set)) == sort!(reference_forest)
reference_skeleton = rootedtree([1, 2, 2, 3])
@test reference_skeleton == partition_skeleton(t, edge_set)

_forests = create_cache(partition_forest!, t)
@test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set)
end

let t = rootedtree([1, 2, 3, 3, 3])
Expand All @@ -365,6 +389,9 @@ end
@test sort!(partition_forest(t, edge_set)) == sort!(reference_forest)
reference_skeleton = rootedtree([1, 2, 3])
@test reference_skeleton == partition_skeleton(t, edge_set)

_forests = create_cache(partition_forest!, t)
@test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set)
end

# additional tests not included in the examples of the paper
Expand All @@ -375,6 +402,21 @@ end
@test sort!(partition_forest(t, edge_set)) == sort!(reference_forest)
reference_skeleton = rootedtree([1, 2])
@test reference_skeleton == partition_skeleton(t, edge_set)

_forests = create_cache(partition_forest!, t)
@test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set)
end

@testset "memoization" begin
for o in 1:8
t = rootedtree(collect(1:o))
edge_set = rand(Bool, order(t) - 1)
_forests = create_cache(partition_forest!, t)
for t in RootedTreeIterator(o)
edge_set = rand(Bool, order(t) - 1)
@test partition_forest(t, edge_set) == partition_forest!(_forests, t, edge_set)
end
end
end
end

Expand Down