Skip to content

Commit e9dff58

Browse files
fused e_mul_xj and weighted option for adjacency_matrix (#107)
* weighted option for adjacency_matrix * fused e_mul_xj
1 parent a53fd27 commit e9dff58

File tree

8 files changed

+228
-28
lines changed

8 files changed

+228
-28
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphNeuralNetworks"
22
uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
33
authors = ["Carlo Lucibello and contributors"]
4-
version = "0.3.9"
4+
version = "0.3.10"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

perf/bench_gnn.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using SparseArrays
2+
using GraphNeuralNetworks
3+
using BenchmarkTools
4+
import Random: seed!
5+
using LinearAlgebra
6+
7+
n = 1024
8+
seed!(0)
9+
A = sprand(n, n, 0.01)
10+
b = rand(1, n)
11+
B = rand(100, n)
12+
13+
g = GNNGraph(
14+
A,
15+
ndata=(; b=b, B=B),
16+
edata=(; A=reshape(A.nzval, 1, :)),
17+
graph_type=:coo # changing to :sparse has little effect on performance
18+
)
19+
20+
function spmv(g)
21+
propagate(
22+
(xi, xj, e) -> e .* xj , # same as e_mul_xj
23+
g, +; xj=g.ndata.b, e=g.edata.A
24+
)
25+
end
26+
27+
function spmm1(g)
28+
propagate(
29+
(xi, xj, e) -> e .* xj , # same as e_mul_xj
30+
g, +; xj=g.ndata.B, e=g.edata.A
31+
)
32+
end
33+
function spmm2(g)
34+
propagate(
35+
e_mul_xj,
36+
g, +; xj=g.ndata.B, e=vec(g.edata.A)
37+
)
38+
end
39+
40+
# @assert isequal(spmv(g), b * A) # true
41+
# @btime spmv(g) # ~5 ms
42+
# @btime b * A # ~32 us
43+
44+
@assert isequal(spmm1(g), B * A) # true
45+
@assert isequal(spmm2(g), B * A) # true
46+
@btime spmm1(g) # ~9 ms
47+
@btime spmm2(g) # ~9 ms
48+
@btime B * A # ~400 us
49+
50+
51+
function spmm_copyxj_fused(g)
52+
propagate(
53+
copy_xj,
54+
g, +; xj=g.ndata.B
55+
)
56+
end
57+
58+
function spmm_copyxj_unfused(g)
59+
propagate(
60+
(xi, xj, e) -> xj,
61+
g, +; xj=g.ndata.B
62+
)
63+
end
64+
65+
Adj = map(x -> x > 0 ? 1 : 0, A)
66+
@assert spmm_copyxj_unfused(g) B * Adj
67+
@assert spmm_copyxj_fused(g) B * Adj # bug fixed in #107
68+
69+
@btime spmm_copyxj_fused(g) # 268.614 μs (22 allocations: 1.13 MiB)
70+
@btime spmm_copyxj_unfused(g) # 4.263 ms (52855 allocations: 12.23 MiB)
71+
@btime B * Adj # 196.135 μs (2 allocations: 800.05 KiB)
72+
73+
println()

src/GNNGraphs/convert.jl

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
### CONVERT_TO_COO REPRESENTATION ########
22

3-
function to_coo(coo::COO_T; dir=:out, num_nodes=nothing)
3+
function to_coo(coo::COO_T; dir=:out, num_nodes=nothing, weighted=true)
44
s, t, val = coo
55
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
66
@assert isnothing(val) || length(val) == length(s)
@@ -10,21 +10,26 @@ function to_coo(coo::COO_T; dir=:out, num_nodes=nothing)
1010
@assert max(maximum(s), maximum(t)) <= num_nodes
1111
end
1212
num_edges = length(s)
13+
if !weighted
14+
coo = (s, t, nothing)
15+
end
1316
return coo, num_nodes, num_edges
1417
end
1518

16-
function to_coo(A::SPARSE_T; dir=:out, num_nodes=nothing)
19+
function to_coo(A::SPARSE_T; dir=:out, num_nodes=nothing, weighted=true)
1720
s, t, v = findnz(A)
1821
if dir == :in
1922
s, t = t, s
2023
end
2124
num_nodes = isnothing(num_nodes) ? size(A, 1) : num_nodes
2225
num_edges = length(s)
23-
26+
if !weighted
27+
v = nothing
28+
end
2429
return (s, t, v), num_nodes, num_edges
2530
end
2631

27-
function to_coo(A::ADJMAT_T; dir=:out, num_nodes=nothing)
32+
function to_coo(A::ADJMAT_T; dir=:out, num_nodes=nothing, weighted=true)
2833
nz = findall(!=(0), A) # vec of cartesian indexes
2934
s, t = ntuple(i -> map(t->t[i], nz), 2)
3035
v = A[nz]
@@ -33,10 +38,13 @@ function to_coo(A::ADJMAT_T; dir=:out, num_nodes=nothing)
3338
end
3439
num_nodes = isnothing(num_nodes) ? size(A, 1) : num_nodes
3540
num_edges = length(s)
41+
if !weighted
42+
v = nothing
43+
end
3644
return (s, t, v), num_nodes, num_edges
3745
end
3846

39-
function to_coo(adj_list::ADJLIST_T; dir=:out, num_nodes=nothing)
47+
function to_coo(adj_list::ADJLIST_T; dir=:out, num_nodes=nothing, weighted=true)
4048
@assert dir [:out, :in]
4149
num_nodes = length(adj_list)
4250
num_edges = sum(length.(adj_list))
@@ -64,7 +72,7 @@ end
6472

6573
to_dense(A::AbstractSparseMatrix, x...; kws...) = to_dense(collect(A), x...; kws...)
6674

67-
function to_dense(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing)
75+
function to_dense(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing, weighted=true)
6876
@assert dir [:out, :in]
6977
T = T === nothing ? eltype(A) : T
7078
num_nodes = size(A, 1)
@@ -77,10 +85,13 @@ function to_dense(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing)
7785
if T != eltype(A)
7886
A = T.(A)
7987
end
88+
if !weighted
89+
A = map(x -> x > 0 ? T(1) : T(0), A)
90+
end
8091
return A, num_nodes, num_edges
8192
end
8293

83-
function to_dense(adj_list::ADJLIST_T, T=nothing; dir=:out, num_nodes=nothing)
94+
function to_dense(adj_list::ADJLIST_T, T=nothing; dir=:out, num_nodes=nothing, weighted=true)
8495
@assert dir [:out, :in]
8596
num_nodes = length(adj_list)
8697
num_edges = sum(length.(adj_list))
@@ -99,13 +110,16 @@ function to_dense(adj_list::ADJLIST_T, T=nothing; dir=:out, num_nodes=nothing)
99110
A, num_nodes, num_edges
100111
end
101112

102-
function to_dense(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing)
113+
function to_dense(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted=true)
103114
# `dir` will be ignored since the input `coo` is always in source -> target format.
104115
# The output will always be a adjmat in :out format (e.g. A[i,j] denotes from i to j)
105116
s, t, val = coo
106117
n = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
107118
val = isnothing(val) ? eltype(s)(1) : val
108119
T = T === nothing ? eltype(val) : T
120+
if !weighted
121+
val = T(1)
122+
end
109123
A = fill!(similar(s, T, (n, n)), 0)
110124
v = vec(A)
111125
idxs = s .+ n .* (t .- 1)
@@ -116,7 +130,7 @@ end
116130

117131
### SPARSE #############
118132

119-
function to_sparse(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing)
133+
function to_sparse(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing, weighted=true)
120134
@assert dir [:out, :in]
121135
num_nodes = size(A, 1)
122136
@assert num_nodes == size(A, 2)
@@ -131,18 +145,25 @@ function to_sparse(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing)
131145
if !(A isa AbstractSparseMatrix)
132146
A = sparse(A)
133147
end
148+
if !weighted
149+
A = map(x -> x > 0 ? T(1) : T(0), A)
150+
end
134151
return A, num_nodes, num_edges
135152
end
136153

137-
function to_sparse(adj_list::ADJLIST_T, T=nothing; dir=:out, num_nodes=nothing)
154+
function to_sparse(adj_list::ADJLIST_T, T=nothing; dir=:out, num_nodes=nothing, weighted=true)
138155
coo, num_nodes, num_edges = to_coo(adj_list; dir, num_nodes)
139156
return to_sparse(coo; num_nodes)
140157
end
141158

142-
function to_sparse(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing)
159+
function to_sparse(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted=true)
143160
s, t, eweight = coo
144161
T = T === nothing ? (eweight === nothing ? eltype(s) : eltype(eweight)) : T
145-
eweight = eweight === nothing ? fill!(similar(s, T), 1) : eweight
162+
163+
if eweight === nothing || !weighted
164+
eweight = fill!(similar(s, T), 1)
165+
end
166+
146167
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
147168
A = sparse(s, t, eweight, num_nodes, num_nodes)
148169
num_edges = nnz(A)

src/GNNGraphs/query.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,20 +132,35 @@ end
132132
adjacency_list(g::GNNGraph; dir=:out) = adjacency_list(g, 1:g.num_nodes; dir)
133133

134134

135-
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=eltype(g); dir=:out)
135+
"""
136+
adjacency_matrix(g::GNNGraph, T=eltype(g); dir=:out, weighted=true)
137+
138+
Return the adjacency matrix `A` for the graph `g`.
139+
140+
If `dir=:out`, `A[i,j] > 0` denotes the presence of an edge from node `i` to node `j`.
141+
If `dir=:in` instead, `A[i,j] > 0` denotes the presence of an edge from node `j` to node `i`.
142+
143+
User may specify the eltype `T` of the returned matrix.
144+
145+
If `weighted=true`, the `A` will contain the edge weigths if any, otherwise the elements of `A` will be either 0 or 1.
146+
"""
147+
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=eltype(g); dir=:out, weighted=true)
136148
if g.graph[1] isa CuVector
137149
# TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
138-
A, n, m = to_dense(g.graph, T, num_nodes=g.num_nodes)
150+
A, n, m = to_dense(g.graph, T; num_nodes=g.num_nodes, weighted)
139151
else
140-
A, n, m = to_sparse(g.graph, T, num_nodes=g.num_nodes)
152+
A, n, m = to_sparse(g.graph, T; num_nodes=g.num_nodes, weighted)
141153
end
142154
@assert size(A) == (n, n)
143155
return dir == :out ? A : A'
144156
end
145157

146-
function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=eltype(g); dir=:out)
158+
function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=eltype(g); dir=:out, weighted=true)
147159
@assert dir [:in, :out]
148160
A = g.graph
161+
if !weighted
162+
A = map(>(0), A)
163+
end
149164
A = T != eltype(A) ? T.(A) : A
150165
return dir == :out ? A : A'
151166
end

src/layers/conv.jl

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ and optionally an edge weight vector.
2525
- `bias`: Add learnable bias. Default `true`.
2626
- `init`: Weights' initializer. Default `glorot_uniform`.
2727
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
28-
- `use_edge_weight`. If `true`, consider the edge weights in the input graph (if available).
29-
If `add_self_loops=true` the new weights will be set to 1. Default `false`.
28+
- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
29+
If `add_self_loops=true` the new weights will be set to 1. Default `false`.
3030
3131
# Examples
3232
@@ -81,17 +81,13 @@ function (l::GCNConv)(g::GNNGraph{<:COO_T}, x::AbstractMatrix)
8181
return l(g, x, edge_weight)
8282
end
8383

84-
function (l::GCNConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix)
85-
edge_weight = nothing
86-
return l(g, x, edge_weight)
87-
end
88-
89-
function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}, edge_weight::EW) where
84+
function (l::GCNConv)(g::GNNGraph{<:COO_T}, x::AbstractMatrix{T}, edge_weight::EW) where
9085
{T, EW<:Union{Nothing,AbstractVector}}
9186

9287
if l.add_self_loops
9388
g = add_self_loops(g)
9489
if edge_weight !== nothing
90+
# TODO for ADJMAT_T the new edges are not generally at the end
9591
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
9692
@assert length(edge_weight) == g.num_edges
9793
end
@@ -116,6 +112,33 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}, edge_weight::EW) where
116112
return l.σ.(x .+ l.bias)
117113
end
118114

115+
# TODO merge the ADJMAT_T and COO_T methods
116+
# The main problem is handling the weighted case for both.
117+
function (l::GCNConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix{T}) where T
118+
if l.add_self_loops
119+
g = add_self_loops(g)
120+
end
121+
Dout, Din = size(l.weight)
122+
if Dout < Din
123+
# multiply before convolution if it is more convenient, otherwise multiply after
124+
x = l.weight * x
125+
end
126+
d = degree(g, T; dir=:in, edge_weight=l.use_edge_weight)
127+
c = 1 ./ sqrt.(d)
128+
x = x .* c'
129+
A = adjacency_matrix(g, weighted=l.use_edge_weight)
130+
x = x * A
131+
x = x .* c'
132+
if Dout >= Din
133+
x = l.weight * x
134+
end
135+
return l.σ.(x .+ l.bias)
136+
end
137+
138+
function (l::GCNConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector)
139+
g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO
140+
return l(g, x, edge_weight)
141+
end
119142

120143
function Base.show(io::IO, l::GCNConv)
121144
out, in = size(l.weight)

src/msgpass.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,17 +179,27 @@ function e_mul_xj(xi, xj::AbstractArray{Tj,Nj}, e::AbstractArray{Te,Ne}) where {
179179
end
180180

181181
function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
182-
A = adjacency_matrix(g)
182+
A = adjacency_matrix(g, weighted=false)
183183
return xj * A
184184
end
185185

186+
# for weighted convolution
187+
function propagate(::typeof(e_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e::AbstractVector)
188+
s, t = edge_index(g)
189+
g = GNNGraph((s, t, e); g.num_nodes)
190+
A = adjacency_matrix(g, weighted=true)
191+
return xj * A
192+
end
193+
194+
195+
186196
## avoid the fast path on gpu until we have better cuda support
187197
function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
188198
propagate((xi,xj,e)->copy_xj(xi,xj,e), g, +, xi, xj, e)
189199
end
190200

191201
# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
192-
# A = adjacency_matrix(g)
202+
# A = adjacency_matrix(g, weigthed=false)
193203
# D = compute_degree(A)
194204
# return xj * A * D
195205
# end

test/GNNGraphs/query.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,15 @@
9595

9696
@testset "adjacency_matrix" begin
9797
a = sprand(5, 5, 0.5)
98+
abin = map(x -> x > 0 ? 1 : 0, a)
99+
98100
g = GNNGraph(a, graph_type=GRAPH_T)
99101
A = adjacency_matrix(g, Float32)
100-
@test a A
102+
@test A a
101103
@test eltype(A) == Float32
104+
105+
Abin = adjacency_matrix(g, Float32, weighted=false)
106+
@test Abin abin
107+
@test eltype(Abin) == Float32
102108
end
103109
end

0 commit comments

Comments
 (0)