Skip to content

Commit acf4b6a

Browse files
Coloring refinement algorithm (#444)
* add coloring refinment algorithm * also in GNNlib * docs
1 parent 3bcafbe commit acf4b6a

File tree

6 files changed

+147
-4
lines changed

6 files changed

+147
-4
lines changed

GNNlib/src/GNNGraphs/GNNGraphs.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ include("operators.jl")
101101

102102
include("convert.jl")
103103
include("utils.jl")
104+
export sort_edge_index,
105+
color_refinement
104106

105107
include("gatherscatter.jl")
106108
# _gather, _scatter

GNNlib/src/GNNGraphs/utils.jl

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ end
4949

5050
sort_edge_index(eindex::Tuple) = sort_edge_index(eindex...)
5151

52+
"""
53+
sort_edge_index(ei::Tuple) -> u', v'
54+
sort_edge_index(u, v) -> u', v'
55+
56+
Return a sorted version of the tuple of vectors `ei = (u, v)`,
57+
applying a common permutation to `u` and `v`.
58+
The sorting is lexycographic, that is the pairs `(ui, vi)`
59+
are sorted first according to the `ui` and then according to `vi`.
60+
"""
5261
function sort_edge_index(u, v)
5362
uv = collect(zip(u, v))
5463
p = sortperm(uv) # isless lexicographically defined for tuples
@@ -301,4 +310,56 @@ end
301310
@non_differentiable normalize_graphdata(::Nothing)
302311

303312
iscuarray(x::AbstractArray) = false
304-
@non_differentiable iscuarray(::Any)
313+
@non_differentiable iscuarray(::Any)
314+
315+
316+
@doc raw"""
317+
color_refinement(g::GNNGraph, [x0]) -> x, num_colors, niters
318+
319+
The color refinement algorithm for graph coloring.
320+
Given a graph `g` and an initial coloring `x0`, the algorithm
321+
iteratively refines the coloring until a fixed point is reached.
322+
323+
At each iteration the algorithm computes a hash of the coloring and the sorted list of colors
324+
of the neighbors of each node. This hash is used to determine if the coloring has changed.
325+
326+
```math
327+
x_i' = hashmap((x_i, sort([x_j for j \in N(i)]))).
328+
````
329+
330+
This algorithm is related to the 1-Weisfeiler-Lehman algorithm for graph isomorphism testing.
331+
332+
# Arguments
333+
- `g::GNNGraph`: The graph to color.
334+
- `x0::AbstractVector{<:Integer}`: The initial coloring. If not provided, all nodes are colored with 1.
335+
336+
# Returns
337+
- `x::AbstractVector{<:Integer}`: The final coloring.
338+
- `num_colors::Int`: The number of colors used.
339+
- `niters::Int`: The number of iterations until convergence.
340+
"""
341+
color_refinement(g::GNNGraph) = color_refinement(g, ones(Int, g.num_nodes))
342+
343+
function color_refinement(g::GNNGraph, x0::AbstractVector{<:Integer})
344+
@assert length(x0) == g.num_nodes
345+
s, t = edge_index(g)
346+
t, s = sort_edge_index(t, s) # sort by target
347+
degs = degree(g, dir=:in)
348+
x = x0
349+
350+
hashmap = Dict{UInt64, Int}()
351+
x′ = zeros(Int, length(x0))
352+
niters = 0
353+
while true
354+
xneigs = chunk(x[s], size=degs)
355+
for (i, (xi, xineigs)) in enumerate(zip(x, xneigs))
356+
idx = hash((xi, sort(xineigs)))
357+
x′[i] = get!(hashmap, idx, length(hashmap) + 1)
358+
end
359+
niters += 1
360+
x == x′ && break
361+
x = x′
362+
end
363+
num_colors = length(union(x))
364+
return x, num_colors, niters
365+
end

docs/src/api/gnngraph.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ Pages = ["transform.jl"]
5252
Private = false
5353
```
5454

55+
## Utils
56+
57+
```@docs
58+
GNNGraphs.sort_edge_index
59+
GNNGraphs.color_refinement
60+
```
61+
5562
## Generate
5663

5764
```@autodocs

src/GNNGraphs/GNNGraphs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import KrylovKit
1414
using ChainRulesCore
1515
using LinearAlgebra, Random, Statistics
1616
import MLUtils
17-
using MLUtils: getobs, numobs, ones_like, zeros_like
17+
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk
1818
import Functors
1919

2020
include("chainrules.jl") # hacks for differentiability
@@ -104,6 +104,7 @@ include("operators.jl")
104104

105105
include("convert.jl")
106106
include("utils.jl")
107+
export sort_edge_index, color_refinement
107108

108109
include("gatherscatter.jl")
109110
# _gather, _scatter

src/GNNGraphs/utils.jl

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,22 @@ end
4949

5050
sort_edge_index(eindex::Tuple) = sort_edge_index(eindex...)
5151

52+
"""
53+
sort_edge_index(ei::Tuple) -> u', v'
54+
sort_edge_index(u, v) -> u', v'
55+
56+
Return a sorted version of the tuple of vectors `ei = (u, v)`,
57+
applying a common permutation to `u` and `v`.
58+
The sorting is lexycographic, that is the pairs `(ui, vi)`
59+
are sorted first according to the `ui` and then according to `vi`.
60+
"""
5261
function sort_edge_index(u, v)
5362
uv = collect(zip(u, v))
5463
p = sortperm(uv) # isless lexicographically defined for tuples
5564
return u[p], v[p]
5665
end
5766

5867

59-
6068
cat_features(x1::Nothing, x2::Nothing) = nothing
6169
cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims = ndims(x1))
6270
function cat_features(x1::Union{Number, AbstractVector}, x2::Union{Number, AbstractVector})
@@ -301,4 +309,56 @@ end
301309
@non_differentiable normalize_graphdata(::Nothing)
302310

303311
iscuarray(x::AbstractArray) = false
304-
@non_differentiable iscuarray(::Any)
312+
@non_differentiable iscuarray(::Any)
313+
314+
315+
@doc raw"""
316+
color_refinement(g::GNNGraph, [x0]) -> x, num_colors, niters
317+
318+
The color refinement algorithm for graph coloring.
319+
Given a graph `g` and an initial coloring `x0`, the algorithm
320+
iteratively refines the coloring until a fixed point is reached.
321+
322+
At each iteration the algorithm computes a hash of the coloring and the sorted list of colors
323+
of the neighbors of each node. This hash is used to determine if the coloring has changed.
324+
325+
```math
326+
x_i' = hashmap((x_i, sort([x_j for j \in N(i)]))).
327+
````
328+
329+
This algorithm is related to the 1-Weisfeiler-Lehman algorithm for graph isomorphism testing.
330+
331+
# Arguments
332+
- `g::GNNGraph`: The graph to color.
333+
- `x0::AbstractVector{<:Integer}`: The initial coloring. If not provided, all nodes are colored with 1.
334+
335+
# Returns
336+
- `x::AbstractVector{<:Integer}`: The final coloring.
337+
- `num_colors::Int`: The number of colors used.
338+
- `niters::Int`: The number of iterations until convergence.
339+
"""
340+
color_refinement(g::GNNGraph) = color_refinement(g, ones(Int, g.num_nodes))
341+
342+
function color_refinement(g::GNNGraph, x0::AbstractVector{<:Integer})
343+
@assert length(x0) == g.num_nodes
344+
s, t = edge_index(g)
345+
t, s = sort_edge_index(t, s) # sort by target
346+
degs = degree(g, dir=:in)
347+
x = x0
348+
349+
hashmap = Dict{UInt64, Int}()
350+
x′ = zeros(Int, length(x0))
351+
niters = 0
352+
while true
353+
xneigs = chunk(x[s], size=degs)
354+
for (i, (xi, xineigs)) in enumerate(zip(x, xneigs))
355+
idx = hash((xi, sort(xineigs)))
356+
x′[i] = get!(hashmap, idx, length(hashmap) + 1)
357+
end
358+
niters += 1
359+
x == x′ && break
360+
x = x′
361+
end
362+
num_colors = length(union(x))
363+
return x, num_colors, niters
364+
end

test/GNNGraphs/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,15 @@
4848
@test sdec == snew
4949
@test tdec == tnew
5050
end
51+
52+
@testset "color_refinment" begin
53+
g = rand_graph(10, 20, seed=17, graph_type = GRAPH_T)
54+
x0 = ones(Int, 10)
55+
x, ncolors, niters = color_refinement(g, x0)
56+
@test ncolors == 8
57+
@test niters == 2
58+
@test x == [4, 5, 6, 7, 8, 5, 8, 9, 10, 11]
59+
60+
x2, _, _ = color_refinement(g)
61+
@test x2 == x
62+
end

0 commit comments

Comments
 (0)