Skip to content

Commit 4464550

Browse files
Merge pull request #76 from CarloLucibello/cl/eweight
Support edge weights in GCNConv
2 parents 5d53d05 + ce10af3 commit 4464550

File tree

13 files changed

+300
-74
lines changed

13 files changed

+300
-74
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ on:
66
push:
77
branches:
88
- master
9-
tags: '*'
109
jobs:
1110
test:
1211
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
@@ -45,5 +44,4 @@ jobs:
4544
- uses: julia-actions/julia-processcoverage@v1
4645
- uses: codecov/codecov-action@v1
4746
with:
48-
# token: ${{ secrets.CODECOV_TOKEN }}
4947
file: lcov.info

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2222
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2323
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2424
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
25-
TestEnv = "1e6cf692-eddd-4d53-88a5-2d735e33781b"
2625

2726
[compat]
2827
Adapt = "3"
@@ -39,7 +38,6 @@ NNlib = "0.7"
3938
NNlibCUDA = "0.1"
4039
Reexport = "1"
4140
StatsBase = "0.32, 0.33"
42-
TestEnv = "1"
4341
julia = "1.6"
4442

4543
[extras]

docs/src/api/messagepassing.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ propagate
2424
copy_xi
2525
copy_xj
2626
xi_dot_xj
27+
e_mul_xj
2728
```

docs/src/gnngraph.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,25 @@ g.ndata.z
114114
g.edata.e
115115
```
116116

117+
## Edge weights
118+
119+
It is common to denote scalar edge features as edge weights. The `GNNGraph` has specific support
120+
for edge weights: they can be stored as part of the internal representations of the graph (COO or adjacency matrix). Some graph convolutional layers, most notably the [`GCNConv`](@ref), can use the edge weights to perform weighted sums over the nodes' neighborhoods.
121+
122+
```julia
123+
julia> source = [1, 1, 2, 2, 3, 3];
124+
125+
julia> target = [2, 3, 1, 3, 1, 2];
126+
127+
julia> weight = [1.0, 0.5, 2.1, 2.3, 4, 4.1];
128+
129+
julia> g = GNNGraph(source, target, weight)
130+
GNNGraph:
131+
num_nodes = 3
132+
num_edges = 6
133+
134+
```
135+
117136
## Batches and Subgraphs
118137

119138
Multiple `GNNGraph`s can be batched together into a single graph

src/GNNGraphs/convert.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ end
6363

6464
to_dense(A::AbstractSparseMatrix, x...; kws...) = to_dense(collect(A), x...; kws...)
6565

66-
function to_dense(A::ADJMAT_T, T::DataType=eltype(A); dir=:out, num_nodes=nothing)
66+
function to_dense(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing)
6767
@assert dir [:out, :in]
68+
T = T === nothing ? eltype(A) : T
6869
num_nodes = size(A, 1)
6970
@assert num_nodes == size(A, 2)
7071
# @assert all(x -> (x == 1) || (x == 0), A)
@@ -78,11 +79,12 @@ function to_dense(A::ADJMAT_T, T::DataType=eltype(A); dir=:out, num_nodes=nothin
7879
return A, num_nodes, num_edges
7980
end
8081

81-
function to_dense(adj_list::ADJLIST_T, T::DataType=Int; dir=:out, num_nodes=nothing)
82+
function to_dense(adj_list::ADJLIST_T, T=nothing; dir=:out, num_nodes=nothing)
8283
@assert dir [:out, :in]
8384
num_nodes = length(adj_list)
8485
num_edges = sum(length.(adj_list))
8586
@assert num_nodes > 0
87+
T = T === nothing ? eltype(adj_list[1]) : T
8688
A = similar(adj_list[1], T, (num_nodes, num_nodes))
8789
if dir == :out
8890
for (i, neigs) in enumerate(adj_list)
@@ -96,26 +98,28 @@ function to_dense(adj_list::ADJLIST_T, T::DataType=Int; dir=:out, num_nodes=noth
9698
A, num_nodes, num_edges
9799
end
98100

99-
function to_dense(coo::COO_T, T::DataType=Int; dir=:out, num_nodes=nothing)
101+
function to_dense(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing)
100102
# `dir` will be ignored since the input `coo` is always in source -> target format.
101103
# The output will always be an adjmat in :out format (i.e. A[i,j] denotes the edge from i to j)
102104
s, t, val = coo
103105
n = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
106+
val = isnothing(val) ? eltype(s)(1) : val
107+
T = T === nothing ? eltype(val) : T
104108
A = fill!(similar(s, T, (n, n)), 0)
105-
if isnothing(val)
106-
A[s .+ n .* (t .- 1)] .= 1 # exploiting linear indexing
107-
else
108-
A[s .+ n .* (t .- 1)] .= val # exploiting linear indexing
109-
end
109+
v = vec(A)
110+
idxs = s .+ n .* (t .- 1)
111+
NNlib.scatter!(+, v, val, idxs)
112+
# A[s .+ n .* (t .- 1)] .= val # exploiting linear indexing
110113
return A, n, length(s)
111114
end
112115

113116
### SPARSE #############
114117

115-
function to_sparse(A::ADJMAT_T, T::DataType=eltype(A); dir=:out, num_nodes=nothing)
118+
function to_sparse(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing)
116119
@assert dir [:out, :in]
117120
num_nodes = size(A, 1)
118121
@assert num_nodes == size(A, 2)
122+
T = T === nothing ? eltype(A) : T
119123
num_edges = A isa AbstractSparseMatrix ? nnz(A) : count(!=(0), A)
120124
if dir == :in
121125
A = A'
@@ -129,13 +133,14 @@ function to_sparse(A::ADJMAT_T, T::DataType=eltype(A); dir=:out, num_nodes=nothi
129133
return A, num_nodes, num_edges
130134
end
131135

132-
function to_sparse(adj_list::ADJLIST_T, T::DataType=Int; dir=:out, num_nodes=nothing)
136+
function to_sparse(adj_list::ADJLIST_T, T=nothing; dir=:out, num_nodes=nothing)
133137
coo, num_nodes, num_edges = to_coo(adj_list; dir, num_nodes)
134138
return to_sparse(coo; dir, num_nodes)
135139
end
136140

137-
function to_sparse(coo::COO_T, T::DataType=Int; dir=:out, num_nodes=nothing)
141+
function to_sparse(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing)
138142
s, t, eweight = coo
143+
T = T === nothing ? eltype(s) : T
139144
eweight = isnothing(eweight) ? fill!(similar(s, T), 1) : eweight
140145
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
141146
A = sparse(s, t, eweight, num_nodes, num_nodes)

src/GNNGraphs/query.jl

Lines changed: 82 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,30 @@ edge_index(g::GNNGraph{<:COO_T}) = g.graph[1:2]
1313

1414
edge_index(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes=g.num_nodes)[1][1:2]
1515

16-
edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3]
16+
get_edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3]
1717

18-
edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes=g.num_nodes)[1][3]
18+
get_edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes=g.num_nodes)[1][3]
1919

2020
Graphs.edges(g::GNNGraph) = zip(edge_index(g)...)
2121

2222
Graphs.edgetype(g::GNNGraph) = Tuple{Int, Int}
23+
nodetype(g::GNNGraph) = Base.eltype(g)
24+
25+
"""
26+
nodetype(g::GNNGraph)
27+
28+
Type of nodes in `g`,
29+
an integer type like `Int`, `Int32`, `UInt16`, ...
30+
"""
31+
function nodetype(g::GNNGraph{<:COO_T}, T=nothing)
32+
s, t = edge_index(g)
33+
return eltype(s)
34+
end
35+
36+
function nodetype(g::GNNGraph{<:ADJMAT_T}, T=nothing)
37+
T !== nothing && return T
38+
return eltype(g.graph)
39+
end
2340

2441
function Graphs.has_edge(g::GNNGraph{<:COO_T}, i::Integer, j::Integer)
2542
s, t = edge_index(g)
@@ -77,7 +94,7 @@ function adjacency_list(g::GNNGraph; dir=:out)
7794
return [fneighs(g, i) for i in 1:g.num_nodes]
7895
end
7996

80-
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=Int; dir=:out)
97+
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=nodetype(g); dir=:out)
8198
if g.graph[1] isa CuVector
8299
# TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
83100
A, n, m = to_dense(g.graph, T, num_nodes=g.num_nodes)
@@ -88,34 +105,85 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=Int; dir=:out
88105
return dir == :out ? A : A'
89106
end
90107

91-
function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=eltype(g.graph); dir=:out)
108+
function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=nodetype(g); dir=:out)
92109
@assert dir [:in, :out]
93110
A = g.graph
94111
A = T != eltype(A) ? T.(A) : A
95112
return dir == :out ? A : A'
96113
end
97114

98-
function Graphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out)
115+
function _get_edge_weight(g, edge_weight)
116+
if edge_weight === true || edge_weight === nothing
117+
ew = get_edge_weight(g)
118+
elseif edge_weight === false
119+
ew = nothing
120+
elseif edge_weight isa AbstractVector
121+
ew = edge_weight
122+
else
123+
error("Invalid edge_weight argument.")
124+
end
125+
return ew
126+
end
127+
128+
"""
129+
degree(g::GNNGraph, T=nothing; dir=:out, edge_weight=true)
130+
131+
Return a vector containing the degrees of the nodes in `g`.
132+
133+
# Arguments
134+
- `g`: A graph.
135+
- `T`: Element type of the returned vector. If `nothing`, is
136+
chosen based on the graph type and will be an integer
137+
if `edge_weight=false`.
138+
- `dir`: For `dir=:out` the degree of a node is counted based on the outgoing edges.
139+
For `dir=:in`, the ingoing edges are used. If `dir=:both` we have the sum of the two.
140+
- `edge_weight`: If `true` and the graph contains weighted edges, the degree will
141+
be weighted. Set to `false` instead to just count the number of
142+
outgoing/ingoing edges.
143+
Alternatively, you can also pass a vector of weights to be used
144+
instead of the graph's own weights.
145+
"""
146+
function Graphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out, edge_weight=true)
99147
s, t = edge_index(g)
100-
T = isnothing(T) ? eltype(s) : T
148+
149+
edge_weight = _get_edge_weight(g, edge_weight)
150+
edge_weight = edge_weight === nothing ? eltype(s)(1) : edge_weight
151+
152+
T = isnothing(T) ? eltype(edge_weight) : T
101153
degs = fill!(similar(s, T, g.num_nodes), 0)
102-
src = 1
103154
if dir [:out, :both]
104-
NNlib.scatter!(+, degs, src, s)
155+
NNlib.scatter!(+, degs, edge_weight, s)
105156
end
106157
if dir [:in, :both]
107-
NNlib.scatter!(+, degs, src, t)
158+
NNlib.scatter!(+, degs, edge_weight, t)
108159
end
109160
return degs
110161
end
111162

112-
function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T=Int; dir=:out)
113-
@assert dir (:in, :out)
114-
A = adjacency_matrix(g, T)
115-
return dir == :out ? vec(sum(A, dims=2)) : vec(sum(A, dims=1))
163+
function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T=nothing; dir=:out, edge_weight=true)
164+
# edge_weight=true or edge_weight=nothing act the same here
165+
@assert !(edge_weight isa AbstractArray) "passing the edge weights is not support by adjacency matrix representations"
166+
@assert dir (:in, :out, :both)
167+
if T === nothing
168+
Nt = nodetype(g)
169+
if edge_weight === false && !(Nt <: Integer)
170+
T = Nt == Float32 ? Int32 :
171+
Nt == Float16 ? Int16 : Int
172+
else
173+
T = Nt
174+
end
175+
end
176+
A = adjacency_matrix(g)
177+
if edge_weight === false
178+
A = map(>(0), A)
179+
end
180+
A = eltype(A) != T ? T.(A) : A
181+
return dir == :out ? vec(sum(A, dims=2)) :
182+
dir == :in ? vec(sum(A, dims=1)) :
183+
vec(sum(A, dims=1)) .+ vec(sum(A, dims=2))
116184
end
117185

118-
function Graphs.laplacian_matrix(g::GNNGraph, T::DataType=Int; dir::Symbol=:out)
186+
function Graphs.laplacian_matrix(g::GNNGraph, T::DataType=nodetype(g); dir::Symbol=:out)
119187
A = adjacency_matrix(g, T; dir=dir)
120188
D = Diagonal(vec(sum(A; dims=2)))
121189
return D - A

src/GNNGraphs/transform.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,23 @@
55
Return a graph with the same features as `g`
66
but also adding edges connecting the nodes to themselves.
77
8-
Nodes with already existing
9-
self-loops will obtain a second self-loop.
8+
Nodes with already existing self-loops will obtain a second self-loop.
9+
10+
If the graph has edge weights, the new edges will have weight 1.
1011
"""
1112
function add_self_loops(g::GNNGraph{<:COO_T})
1213
s, t = edge_index(g)
1314
@assert g.edata === (;)
14-
@assert edge_weight(g) === nothing
15+
ew = get_edge_weight(g)
1516
n = g.num_nodes
1617
nodes = convert(typeof(s), [1:n;])
1718
s = [s; nodes]
1819
t = [t; nodes]
20+
if ew !== nothing
21+
ew = [ew; fill!(similar(ew, n), 1)]
22+
end
1923

20-
GNNGraph((s, t, nothing),
24+
GNNGraph((s, t, ew),
2125
g.num_nodes, length(s), g.num_graphs,
2226
g.graph_indicator,
2327
g.ndata, g.edata, g.gdata)
@@ -39,7 +43,7 @@ function remove_self_loops(g::GNNGraph{<:COO_T})
3943
s, t = edge_index(g)
4044
# TODO remove these constraints
4145
@assert g.edata === (;)
42-
@assert edge_weight(g) === nothing
46+
@assert get_edge_weight(g) === nothing
4347

4448
mask_old_loops = s .!= t
4549
s = s[mask_old_loops]
@@ -61,7 +65,7 @@ function remove_multi_edges(g::GNNGraph{<:COO_T})
6165
# TODO remove these constraints
6266
@assert g.num_graphs == 1
6367
@assert g.edata === (;)
64-
@assert edge_weight(g) === nothing
68+
@assert get_edge_weight(g) === nothing
6569

6670
idxs, idxmax = edge_encoding(s, t, g.num_nodes)
6771
union!(idxs)
@@ -85,7 +89,7 @@ function add_edges(g::GNNGraph{<:COO_T},
8589

8690
@assert length(snew) == length(tnew)
8791
# TODO remove this constraint
88-
@assert edge_weight(g) === nothing
92+
@assert get_edge_weight(g) === nothing
8993

9094
edata = normalize_graphdata(edata, default_name=:e, n=length(snew))
9195
edata = cat_features(g.edata, edata)
@@ -126,7 +130,7 @@ function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph)
126130
s2, t2 = edge_index(g2)
127131
s = vcat(s1, nv1 .+ s2)
128132
t = vcat(t1, nv1 .+ t2)
129-
w = cat_features(edge_weight(g1), edge_weight(g2))
133+
w = cat_features(get_edge_weight(g1), get_edge_weight(g2))
130134
graph = (s, t, w)
131135
ind1 = isnothing(g1.graph_indicator) ? ones_like(s1, Int, nv1) : g1.graph_indicator
132136
ind2 = isnothing(g2.graph_indicator) ? ones_like(s2, Int, nv2) : g2.graph_indicator
@@ -288,7 +292,7 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
288292
graph_indicator = [graphmap[i] for i in g.graph_indicator[node_mask]]
289293

290294
s, t = edge_index(g)
291-
w = edge_weight(g)
295+
w = get_edge_weight(g)
292296
edge_mask = s .∈ Ref(nodes)
293297

294298
if g.graph isa COO_T
@@ -340,7 +344,7 @@ function negative_sample(g::GNNGraph;
340344
@assert g.num_graphs == 1
341345
# Consider self-loops as positive edges
342346
# Construct new graph dropping features
343-
g = add_self_loops(GNNGraph(edge_index(g)))
347+
g = add_self_loops(GNNGraph(edge_index(g), num_nodes=g.num_nodes))
344348

345349
s, t = edge_index(g)
346350
n = g.num_nodes

0 commit comments

Comments
 (0)