Skip to content

WIP: introduce DeleteVector #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions src/DeleteVectors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
module DeleteVectors

export DeleteVector

# mutable struct DeleteVector{T, Parent <: AbstractVector{T}} <: AbstractVector{T}
# parent::Parent
# length::Int
# end

struct DeleteVector{T, Parent <: AbstractVector{T}} <: AbstractVector{T}
parent::Parent
length::Base.RefValue{Int}
end


function DeleteVector(parent::AbstractVector)
DeleteVector(parent, Ref(length(parent)))
end

function DeleteVector{T}(sizehint::Integer = 4) where {T}
sizehint >= 0 || throw(DomainError(sizehint, "Invalid initial size."))
DeleteVector(Vector{T}(undef, sizehint), Ref(0))
end

@inline Base.parent(v::DeleteVector) = v.parent
@inline Base.pointer(v::DeleteVector) = pointer(parent(v))
@inline Base.pointer(v::DeleteVector, i::Integer) = pointer(parent(v), i)

@inline Base.size(v::DeleteVector) = (v.length[], )
Base.IndexStyle(::DeleteVector) = IndexLinear()

@inline function Base.getindex(v::DeleteVector, i)
@boundscheck checkbounds(v, i)
@inbounds v.parent[i]
end

@inline function Base.setindex!(v::DeleteVector, x, i)
@boundscheck checkbounds(v, i)
@inbounds v.parent[i] = x
end


Base.copy(v::DeleteVector) = DeleteVector(copy(parent(v)), Ref(length(v)))
Base.similar(v::DeleteVector) = DeleteVector(similar(parent(v)), Ref(length(v)))

function Base.copyto!(dest::DeleteVector, doffs::Integer,
src::DeleteVector, soffs::Integer, n::Integer)
copyto!(parent(dest), doffs, parent(src), soffs, n)
end

Base.view(v::DeleteVector, inds::UnitRange) = view(parent(v), inds)


function Base.sizehint!(v::DeleteVector, n)
if length(parent(v)) < n || n >= length(v)
resize!(v.parent, n)
end
nothing
end

function Base.resize!(v::DeleteVector, n)
if length(parent(v)) < n
resize!(v.parent, n)
end
v.length[] = n
v
end

Base.empty!(v::DeleteVector) = (v.length[] = 0; v)


function Base.deleteat!(v::DeleteVector, i::Integer)
@boundscheck checkbounds(v, i)
p = parent(v)
for j in i+1:lastindex(v)
@inbounds p[j-1] = p[j]
end
v.length[] -= 1
v
end

function Base.deleteat!(v::DeleteVector, inds::UnitRange)
@boundscheck checkbounds(v, inds)
p = parent(v)
i = first(inds)
offset = length(inds)
for j in i+offset:lastindex(v)
@inbounds p[j-offset] = p[j]
end
v.length[] -= offset
v
end


Base.findprev(v::DeleteVector, i::Integer) = findprev(parent(v), i)
# TODO: findnext, findlast, findfirst, function arguments

end # module
42 changes: 26 additions & 16 deletions src/RootedTrees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ using Latexify: Latexify
using RecipesBase: RecipesBase


include("DeleteVectors.jl")
using .DeleteVectors


export RootedTree, rootedtree, rootedtree!, RootedTreeIterator

export butcher_representation
Expand Down Expand Up @@ -369,7 +373,11 @@ end
const BUFFER_LENGTH = 64
const CANONICAL_REPRESENTATION_BUFFER = Vector{Vector{Int}}()

function canonical_representation!(t::RootedTree{Int, Vector{Int}})
function canonical_representation!(t::Union{
RootedTree{Int, Vector{Int}},
RootedTree{Int, DeleteVector{Int, Vector{Int}}},
RootedTree{Int, SubArray{Int, 1, Vector{Int}, Tuple{UnitRange{Int}}, true}},
RootedTree{Int, SubArray{Int, 1, Vector{Int}, Tuple{Base.OneTo{Int}}, true}}})
if order(t) <= BUFFER_LENGTH
buffer = CANONICAL_REPRESENTATION_BUFFER[Threads.threadid()]
else
Expand Down Expand Up @@ -616,10 +624,10 @@ Section 2.3 of
Foundations of Computational Mathematics
[DOI: 10.1007/s10208-010-9065-1](https://doi.org/10.1007/s10208-010-9065-1)
"""
struct PartitionForestIterator{T, V, Tree<:RootedTree{T, V}}
struct PartitionForestIterator{T, V, Tree<:RootedTree{T, V}, EdgeSet<:AbstractVector{Bool}}
t::Tree
level_sequence::V
edge_set::Vector{Bool}
edge_set::EdgeSet
end

function PartitionForestIterator(t::RootedTree, edge_set)
Expand Down Expand Up @@ -713,8 +721,8 @@ function partition_skeleton(t::RootedTree, edge_set)
end

edge_set_copy = copy(edge_set)
skeleton = RootedTree(copy(t.level_sequence), true)
return partition_skeleton!(skeleton.level_sequence, edge_set_copy)
level_sequence_copy = copy(t.level_sequence)
return partition_skeleton!(level_sequence_copy, edge_set_copy)
end

# internal in-place version of partition_skeleton modifying the inputs
Expand All @@ -730,7 +738,8 @@ function partition_skeleton!(level_sequence, edge_set)
# Remember the convention node = edge + 1
subtree_root_index = edge_to_contract + 1
subtree_last_index = subtree_root_index + 1
while subtree_last_index <= length(level_sequence)
length_level_sequence = length(level_sequence)
while subtree_last_index <= length_level_sequence
if level_sequence[subtree_last_index] > level_sequence[subtree_root_index]
level_sequence[subtree_last_index] -= 1
subtree_last_index += 1
Expand Down Expand Up @@ -787,7 +796,7 @@ end
# A helper function to comute the binary representation of an integer `n` as
# a vector of `Bool`s. This is a more efficient version of
# binary_digits!(digits, n) = digits!(digits, n, base=2)
function binary_digits!(digits::Vector{Bool}, n::Int)
function binary_digits!(digits::AbstractVector{Bool}, n::Int)
bit = 1
for i in eachindex(digits)
digits[i] = n & bit > 0
Expand Down Expand Up @@ -819,23 +828,24 @@ Section 2.3 of
Foundations of Computational Mathematics
[DOI: 10.1007/s10208-010-9065-1](https://doi.org/10.1007/s10208-010-9065-1)
"""
struct PartitionIterator{T, Tree<:RootedTree{T}}
struct PartitionIterator{T, Tree<:RootedTree{T}, EdgeSet<:AbstractVector{Bool}}
t::Tree
forest::PartitionForestIterator{T, Vector{T}, RootedTree{T, Vector{T}}}
skeleton::RootedTree{T, Vector{T}}
edge_set::Vector{Bool}
edge_set_tmp::Vector{Bool}
edge_set::EdgeSet
edge_set_tmp::EdgeSet
end

function PartitionIterator(t::Tree) where {T, Tree<:RootedTree{T}}
skeleton = RootedTree(Vector{T}(undef, order(t)), true)
edge_set = Vector{Bool}(undef, order(t) - 1)
edge_set = DeleteVector(Vector{Bool}(undef, order(t) - 1))
edge_set_tmp = similar(edge_set)

t_forest = RootedTree(Vector{T}(undef, order(t)), true)
level_sequence = similar(t_forest.level_sequence)
forest = PartitionForestIterator(t_forest, level_sequence, edge_set_tmp)
PartitionIterator{T, Tree}(t, forest, skeleton, edge_set, edge_set_tmp)
PartitionIterator{T, Tree, typeof(edge_set)}(
t, forest, skeleton, edge_set, edge_set_tmp)
end

# Allocate global buffer for `PartitionIterator` for each thread
Expand All @@ -857,22 +867,22 @@ function PartitionIterator(t::RootedTree{Int, Vector{Int}})
resize!(level_sequence, order_t)
buffer_skeleton = PARTITION_ITERATOR_BUFFER_SKELETON[id]
resize!(buffer_skeleton, order_t)
edge_set = PARTITION_ITERATOR_BUFFER_EDGE_SET[id]
edge_set = PARTITION_ITERATOR_BUFFER_EDGE_SET[id] |> DeleteVector
resize!(edge_set, order_t - 1)
edge_set_tmp = PARTITION_ITERATOR_BUFFER_EDGE_SET_TMP[id]
edge_set_tmp = PARTITION_ITERATOR_BUFFER_EDGE_SET_TMP[id] |> DeleteVector
resize!(edge_set_tmp, order_t - 1)
else
buffer_forest_t = Vector{Int}(undef, order_t)
level_sequence = similar(buffer_forest_t)
buffer_skeleton = similar(buffer_forest_t)
edge_set = Vector{Bool}(undef, order_t - 1)
edge_set = Vector{Bool}(undef, order_t - 1) |> DeleteVector
edge_set_tmp = similar(edge_set)
end

skeleton = RootedTree(buffer_skeleton, true)
t_forest = RootedTree(buffer_forest_t, true)
forest = PartitionForestIterator(t_forest, level_sequence, edge_set_tmp)
PartitionIterator{Int, RootedTree{Int, Vector{Int}}}(
PartitionIterator{Int, RootedTree{Int, Vector{Int}}, typeof(edge_set)}(
t, forest, skeleton, edge_set, edge_set_tmp)
end

Expand Down