diff --git a/src/Graphs.jl b/src/Graphs.jl index 86bc0946..bb99d30e 100644 --- a/src/Graphs.jl +++ b/src/Graphs.jl @@ -253,6 +253,7 @@ export has_negative_edge_cycle_spfa, has_negative_edge_cycle, enumerate_paths, + enumerate_paths!, johnson_shortest_paths, floyd_warshall_shortest_paths, transitiveclosure!, diff --git a/src/shortestpaths/bellman-ford.jl b/src/shortestpaths/bellman-ford.jl index 875d5937..c3a803c2 100644 --- a/src/shortestpaths/bellman-ford.jl +++ b/src/shortestpaths/bellman-ford.jl @@ -107,6 +107,7 @@ single destinations, the path is represented by a single vector of vertices, and will be length 0 if the path does not exist. ### Implementation Notes + For Floyd-Warshall path states, please note that the output is a bit different, since this algorithm calculates all shortest paths for all pairs of vertices: `enumerate_paths(state)` will return a vector (indexed by source vertex) of @@ -116,13 +117,47 @@ to all other vertices. In addition, `enumerate_paths(state, v, d)` will return a vector representing the path from vertex `v` to vertex `d`. """ function enumerate_paths(state::AbstractPathState, vs::AbstractVector{<:Integer}) - parents = state.parents - T = eltype(parents) + T = eltype(state.parents) + all_paths = Vector{T}[Vector{eltype(state.parents)}() for _ in 1:length(vs)] + return enumerate_paths!(all_paths, state, vs) +end +enumerate_paths(state::AbstractPathState, v::Integer) = enumerate_paths(state, v:v)[1] +function enumerate_paths(state::AbstractPathState) + return enumerate_paths(state, 1:length(state.parents)) +end + +""" + enumerate_paths!(paths::AbstractVector{<:AbstractVector}, state, vs::AbstractVector{Int}) + +In-place version of [`enumerate_paths`](@ref). + +`paths` must be a `Vector{Vectors{eltype(state.parents)}}`, `state` an `AbstractPathState`, +and `vs`` an AbstractRange or other AbstractVector of `Int`. +See the `enumerate_paths` documentation for details. + +`enumerate_paths!` should be more efficient when used in a loop, +as the same memory can be used for each iteration. +""" +function enumerate_paths!( + all_paths::AbstractVector{<:AbstractVector}, + state::AbstractPathState, + vs::AbstractVector{<:Integer}, +) + Base.require_one_based_indexing(all_paths) + Base.require_one_based_indexing(vs) + length(all_paths) == length(vs) || throw( + ArgumentError( + "length of destination paths $(length(vs)) deos not match length of vs $(length(all_paths))", + ), + ) + + parents = state.parents + T = eltype(state.parents) num_vs = length(vs) - all_paths = Vector{Vector{T}}(undef, num_vs) + for i in 1:num_vs - all_paths[i] = Vector{T}() + empty!(all_paths[i]) index = T(vs[i]) if parents[index] != 0 || parents[index] == index while parents[index] != 0 @@ -135,8 +170,3 @@ function enumerate_paths(state::AbstractPathState, vs::AbstractVector{<:Integer} end return all_paths end - -enumerate_paths(state::AbstractPathState, v::Integer) = enumerate_paths(state, [v])[1] -function enumerate_paths(state::AbstractPathState) - return enumerate_paths(state, [1:length(state.parents);]) -end diff --git a/test/shortestpaths/bellman-ford.jl b/test/shortestpaths/bellman-ford.jl index 77d8bac0..f845db48 100644 --- a/test/shortestpaths/bellman-ford.jl +++ b/test/shortestpaths/bellman-ford.jl @@ -62,6 +62,8 @@ @test getfield.(y.dists, :val) == getfield.(z.dists, :val) == [Inf, 0, 6, 17, 33] @test @inferred(enumerate_paths(z))[2] == [] @test @inferred(enumerate_paths(z))[4] == enumerate_paths(z, 4) == [2, 3, 4] + @test @inferred(enumerate_paths!([[0]], z, 4:4))[1] == [2, 3, 4] + @test_throws ArgumentError enumerate_paths!([[0, 0], [0, 0]], z, 4:4) @test @inferred(!has_negative_edge_cycle(g)) @test @inferred(!has_negative_edge_cycle(g, d3))