Skip to content

Commit af57d22

Browse files
Allow passing random number generator to label-propagation (#95)
This also fixes an issue where the randomized tests for label-propagation would sporadically fail.
1 parent 4c35a89 commit af57d22

File tree

5 files changed

+50
-12
lines changed

5 files changed

+50
-12
lines changed

src/community/label_propagation.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
label_propagation(g, maxiter=1000)
2+
label_propagation(g, maxiter=1000; rng=GLOBAL_RNG)
33
44
Community detection using the label propagation algorithm.
55
Return two vectors: the first is the label number assigned to each node, and
@@ -9,7 +9,10 @@ the second is the convergence history for each node. Will return after
99
### References
1010
- [Raghavan et al.](http://arxiv.org/abs/0709.2938)
1111
"""
12-
function label_propagation(g::AbstractGraph{T}, maxiter=1000) where T
12+
function label_propagation(g::AbstractGraph{T}, maxiter=1000; rng::Union{Nothing, AbstractRNG} = nothing, seed::Union{Nothing, Integer} = nothing) where T
13+
14+
rng = rng_from_rng_or_seed(rng, seed)
15+
1316
n = nv(g)
1417
n == 0 && return (T[], Int[])
1518

@@ -28,11 +31,11 @@ function label_propagation(g::AbstractGraph{T}, maxiter=1000) where T
2831
for (j, node) in enumerate(active_vs)
2932
random_order[j] = node
3033
end
31-
range_shuffle!(1:num_active, random_order)
34+
range_shuffle!(rng, 1:num_active, random_order)
3235
@inbounds for j = 1:num_active
3336
u = random_order[j]
3437
old_comm = label[u]
35-
label[u] = vote!(g, label, c, u)
38+
label[u] = vote!(rng, g, label, c, u)
3639
if old_comm != label[u]
3740
for v in outneighbors(g, u)
3841
push!(active_vs, v)
@@ -59,13 +62,11 @@ mutable struct NeighComm{T<:Integer}
5962
end
6063

6164
"""
62-
range_shuffle!(r, a; seed=-1)
65+
range_shuffle!(rng, r, a)
6366
6467
Fast shuffle Array `a` in UnitRange `r`.
65-
Uses `seed` to initialize the random number generator, defaulting to `Random.GLOBAL_RNG` for `seed=-1`.
6668
"""
67-
function range_shuffle!(r::UnitRange, a::AbstractVector; seed::Int=-1)
68-
rng = getRNG(seed)
69+
function range_shuffle!(rng::AbstractRNG, r::UnitRange, a::AbstractVector)
6970
(r.start > 0 && r.stop <= length(a)) || throw(DomainError(r, "range indices are out of bounds"))
7071
@inbounds for i = length(r):-1:2
7172
j = rand(rng, 1:i)
@@ -76,11 +77,11 @@ function range_shuffle!(r::UnitRange, a::AbstractVector; seed::Int=-1)
7677
end
7778

7879
"""
79-
vote!(g, m, c, u)
80+
vote!(rng, g, m, c, u)
8081
8182
Return the label with greatest frequency.
8283
"""
83-
function vote!(g::AbstractGraph, m::Vector, c::NeighComm, u::Integer)
84+
function vote!(rng::AbstractRNG, g::AbstractGraph, m::Vector, c::NeighComm, u::Integer)
8485
@inbounds for i = 1:c.neigh_last - 1
8586
c.neigh_cnt[c.neigh_pos[i]] = -1
8687
end
@@ -102,7 +103,7 @@ function vote!(g::AbstractGraph, m::Vector, c::NeighComm, u::Integer)
102103
end
103104
end
104105
# ties breaking randomly
105-
range_shuffle!(1:c.neigh_last - 1, c.neigh_pos)
106+
range_shuffle!(rng, 1:c.neigh_last - 1, c.neigh_pos)
106107

107108
result_lbl = zero(eltype(c.neigh_pos))
108109
for lbl in c.neigh_pos

src/utils.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,27 @@ sample(a::AbstractVector, k::Integer; exclude=()) = sample!(getRNG(), collect(a)
4444

4545
getRNG(seed::Integer=-1) = seed >= 0 ? MersenneTwister(seed) : GLOBAL_RNG
4646

47+
"""
48+
rng_from_rng_or_seed(rng, seed)
49+
50+
Helper function for randomized functions that can take a random generator as well as a seed argument.
51+
52+
Currently most randomized functions in this package take a seed integer as an argument.
53+
As modern randomized Julia functions tend to take a random generator instead of a seed,
54+
this function helps with the transition by taking `rng` and `seed` as an argument and
55+
always returning a random number generator.
56+
At least one of these arguments must be `nothing`.
57+
"""
58+
function rng_from_rng_or_seed(rng::Union{Nothing, AbstractRNG}, seed::Union{Nothing, Integer})
59+
60+
# TODO at some point we might emit a deprecation warning if a seed is specified
61+
62+
!(isnothing(seed) || isnothing(rng)) && throw(ArgumentError("Cannot specify both, seed and rng"))
63+
!isnothing(seed) && return getRNG(seed)
64+
isnothing(rng) && return GLOBAL_RNG
65+
return rng
66+
end
67+
4768
"""
4869
insorted(item, collection)
4970

test/community/label_propagation.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
@testset "Label propagation" begin
2+
3+
rng = MersenneTwister(1234)
4+
25
n = 10
36
g10 = complete_graph(n)
47
for g in testgraphs(g10)
58
z = copy(g)
69
for k = 2:5
710
z = blockdiag(z, g)
811
add_edge!(z, (k - 1) * n, k * n)
9-
c, ch = @inferred(label_propagation(z))
12+
c, ch = @inferred(label_propagation(z; rng=rng))
1013
a = collect(n:n:(k * n))
1114
a = Int[div(i - 1, n) + 1 for i = 1:(k * n)]
1215
# check the number of communities

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Graphs.Experimental
44
using Test
55
using SparseArrays
66
using LinearAlgebra
7+
using Compat
78
using DelimitedFiles
89
using Base64
910
using Random

test/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@
3333
end
3434
end
3535

36+
@testset "rng_from_rng_or_seed" begin
37+
@test Graphs.rng_from_rng_or_seed(nothing, nothing) === Random.GLOBAL_RNG
38+
@test Graphs.rng_from_rng_or_seed(nothing, -10) === Random.GLOBAL_RNG
39+
@test Graphs.rng_from_rng_or_seed(nothing, 456) == Graphs.getRNG(456)
40+
@compat if ismutable(Random.GLOBAL_RNG)
41+
@test Graphs.rng_from_rng_or_seed(nothing, 456) !== Random.GLOBAL_RNG
42+
end
43+
rng = Random.MersenneTwister(789)
44+
@test Graphs.rng_from_rng_or_seed(rng, nothing) === rng
45+
@test_throws ArgumentError Graphs.rng_from_rng_or_seed(rng, -1)
46+
end
47+
3648
A = [false, true, false, false, true, true]
3749
@test findall(A) == Graphs.findall!(A, Vector{Int16}(undef, 6))[1:3]
3850
end

0 commit comments

Comments
 (0)