
Commit 9a9370c

add w_mul_xj (#114)
1 parent df5b508 · commit 9a9370c

9 files changed: +100 −48 lines changed

docs/src/api/messagepassing.md

Lines changed: 1 addition & 0 deletions
@@ -26,4 +26,5 @@ copy_xi
 copy_xj
 xi_dot_xj
 e_mul_xj
+w_mul_xj
 ```

src/GNNGraphs/GNNGraphs.jl

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ export add_nodes,
        rand_edge_split,
        remove_self_loops,
        remove_multi_edges,
+       set_edge_weight,
        # from Flux
        batch,
        unbatch,

src/GNNGraphs/convert.jl

Lines changed: 5 additions & 5 deletions
@@ -86,7 +86,7 @@ function to_dense(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing, weighted=
         A = T.(A)
     end
     if !weighted
-        A = map(x -> x > 0 ? T(1) : T(0), A)
+        A = map(x -> ifelse(x > 0, T(1), T(0)), A)
     end
     return A, num_nodes, num_edges
 end
@@ -121,10 +121,10 @@ function to_dense(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted=t
         val = T(1)
     end
     A = fill!(similar(s, T, (n, n)), 0)
-    v = vec(A)
+    v = vec(A) # vec view of A
     idxs = s .+ n .* (t .- 1)
-    NNlib.scatter!(+, v, val, idxs)
-    # A[s .+ n .* (t .- 1)] .= val # exploiting linear indexing
+    # A[idxs] .= val # exploiting linear indexing
+    NNlib.scatter!(+, v, val, idxs) # using scatter instead of indexing since there could be multiple edges
     return A, n, length(s)
 end
 
@@ -146,7 +146,7 @@ function to_sparse(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing, weighted
         A = sparse(A)
     end
     if !weighted
-        A = map(x -> x > 0 ? T(1) : T(0), A)
+        A = map(x -> ifelse(x > 0, T(1), T(0)), A)
     end
     return A, num_nodes, num_edges
 end
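
Note on the `to_dense` change above: `scatter!` is preferred over plain indexed assignment because, with multi-edges, several edges can map to the same linear index, and `A[idxs] .= val` would keep only the last write while `scatter!(+, ...)` accumulates them. A minimal sketch, not part of the commit; the toy graph below is ours:

using NNlib

n = 3
s = [1, 1, 2]          # source nodes; the edge 1 -> 2 appears twice
t = [2, 2, 3]          # target nodes
val = [0.5, 0.5, 1.0]  # per-edge weights

A = zeros(Float64, n, n)
idxs = s .+ n .* (t .- 1)             # linear indices into vec(A), column-major
NNlib.scatter!(+, vec(A), val, idxs)  # vec(A) shares memory with A

A[1, 2]  # == 1.0: the two parallel edges were summed, not overwritten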

src/GNNGraphs/transform.jl

Lines changed: 15 additions & 0 deletions
@@ -97,6 +97,7 @@ end
     add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
 
 Add to graph `g` the edges with source nodes `s` and target nodes `t`.
+Optionally, pass the features `edata` for the new edges.
 """
 function add_edges(g::GNNGraph{<:COO_T},
                    snew::AbstractVector{<:Integer},
@@ -155,6 +156,20 @@ function add_nodes(g::GNNGraph{<:COO_T}, n::Integer; ndata=(;))
                     ndata, g.edata, g.gdata)
 end
 
+"""
+    set_edge_weight(g::GNNGraph, w::AbstractVector)
+
+Set `w` as edge weights in the returned graph.
+"""
+function set_edge_weight(g::GNNGraph, w::AbstractVector)
+    s, t = edge_index(g)
+    @assert length(w) == length(s)
+
+    return GNNGraph((s, t, w),
+                    g.num_nodes, g.num_edges, g.num_graphs,
+                    g.graph_indicator,
+                    g.ndata, g.edata, g.gdata)
+end
 
 function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph)
     nv1, nv2 = g1.num_nodes, g2.num_nodes
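
A hypothetical usage sketch for the new `set_edge_weight` (the graph and weights below are ours; `rand_graph` and `get_edge_weight` come from the package, as the tests below also show):

using GraphNeuralNetworks

g = rand_graph(10, 20)           # 10 nodes, 20 edges, no weights
w = rand(Float32, 20)            # one weight per edge

gw = set_edge_weight(g, w)       # returns a new graph; g is untouched
@assert get_edge_weight(gw) == w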

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ export
        copy_xi,
        xi_dot_xj,
        e_mul_xj,
+       w_mul_xj,
 
        # layers/basic
        GNNLayer,

src/layers/conv.jl

Lines changed: 8 additions & 34 deletions
@@ -74,16 +74,11 @@ function GCNConv(ch::Pair{Int,Int}, σ=identity;
     GCNConv(W, b, σ, add_self_loops, use_edge_weight)
 end
 
-function (l::GCNConv)(g::GNNGraph{<:COO_T}, x::AbstractMatrix)
-    # Extract edge_weight from g if available and l.edge_weight == true,
-    # otherwise return nothing.
-    edge_weight = GNNGraphs._get_edge_weight(g, l.use_edge_weight) # vector or nothing
-    return l(g, x, edge_weight)
-end
-
-function (l::GCNConv)(g::GNNGraph{<:COO_T}, x::AbstractMatrix{T}, edge_weight::EW) where
+function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}, edge_weight::EW=nothing) where
     {T, EW<:Union{Nothing,AbstractVector}}
 
+    @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"
+
     if l.add_self_loops
         g = add_self_loops(g)
         if edge_weight !== nothing
@@ -100,10 +95,12 @@ function (l::GCNConv)(g::GNNGraph{<:COO_T}, x::AbstractMatrix{T}, edge_weight::E
     d = degree(g, T; dir=:in, edge_weight)
     c = 1 ./ sqrt.(d)
     x = x .* c'
-    if edge_weight === nothing
-        x = propagate(copy_xj, g, +, xj=x)
-    else
+    if edge_weight !== nothing
         x = propagate(e_mul_xj, g, +, xj=x, e=edge_weight)
+    elseif l.use_edge_weight
+        x = propagate(w_mul_xj, g, +, xj=x)
+    else
+        x = propagate(copy_xj, g, +, xj=x)
     end
     x = x .* c'
     if Dout >= Din
@@ -112,29 +109,6 @@ function (l::GCNConv)(g::GNNGraph{<:COO_T}, x::AbstractMatrix{T}, edge_weight::E
     return l.σ.(x .+ l.bias)
 end
 
-# TODO merge the ADJMAT_T and COO_T methods
-# The main problem is handling the weighted case for both.
-function (l::GCNConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix{T}) where T
-    if l.add_self_loops
-        g = add_self_loops(g)
-    end
-    Dout, Din = size(l.weight)
-    if Dout < Din
-        # multiply before convolution if it is more convenient, otherwise multiply after
-        x = l.weight * x
-    end
-    d = degree(g, T; dir=:in, edge_weight=l.use_edge_weight)
-    c = 1 ./ sqrt.(d)
-    x = x .* c'
-    A = adjacency_matrix(g, weighted=l.use_edge_weight)
-    x = x * A
-    x = x .* c'
-    if Dout >= Din
-        x = l.weight * x
-    end
-    return l.σ.(x .+ l.bias)
-end
-
 function (l::GCNConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector)
     g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO
     return l(g, x, edge_weight)
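
The refactor collapses the COO and adjacency-matrix methods into a single forward pass with a three-way dispatch on weights: an explicit `edge_weight` argument wins (`e_mul_xj`), otherwise `use_edge_weight=true` pulls the graph's own weights (`w_mul_xj`), otherwise the unweighted path (`copy_xj`) runs. A sketch of how this is expected to be called; the toy sizes and the keyword-constructor call are our assumptions, not from the commit:

using GraphNeuralNetworks

g = rand_graph(10, 20)
x = rand(Float32, 3, 10)                  # 3 input features per node
l = GCNConv(3 => 5, use_edge_weight=true) # assumed keyword, matching the struct field above

y1 = l(g, x)                              # graph's own weights via w_mul_xj (none stored here, so copy_xj semantics)
y2 = l(g, x, rand(Float32, 20))           # explicit weights take priority via e_mul_xj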

src/msgpass.jl

Lines changed: 43 additions & 5 deletions
@@ -149,7 +149,7 @@ _scatter(aggr, m::AbstractArray, t) = NNlib.scatter(aggr, m, t)
 
 
 
-### SPECIALIZATIONS OF PROPAGATE ###
+### MESSAGE FUNCTIONS ###
 """
     copy_xj(xi, xj, e) = xj
 """
@@ -178,26 +178,64 @@ function e_mul_xj(xi, xj::AbstractArray{Tj,Nj}, e::AbstractArray{Te,Ne}) where {
     return e .* xj
 end
 
+"""
+    w_mul_xj(xi, xj, w) = reshape(w, (...)) .* xj
+
+Similar to [`e_mul_xj`](@ref) but specialized on scalar edge features (weights).
+"""
+w_mul_xj(xi, xj::AbstractArray, w::Nothing) = xj # same as copy_xj if no weights
+
+function w_mul_xj(xi, xj::AbstractArray{Tj,Nj}, w::AbstractVector) where {Tj, Nj}
+    w = reshape(w, ntuple(_ -> 1, Nj-1)..., length(w))
+    return w .* xj
+end
+
+
+###### PROPAGATE SPECIALIZATIONS ####################
+
+## COPY_XJ
+
 function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
     A = adjacency_matrix(g, weighted=false)
     return xj * A
 end
 
+## avoid the fast path on gpu until we have better cuda support
+function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
+    propagate((xi,xj,e) -> copy_xj(xi,xj,e), g, +, xi, xj, e)
+end
+
+## E_MUL_XJ
+
 # for weighted convolution
 function propagate(::typeof(e_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e::AbstractVector)
-    s, t = edge_index(g)
-    g = GNNGraph((s, t, e); g.num_nodes)
+    g = set_edge_weight(g, e)
     A = adjacency_matrix(g, weighted=true)
     return xj * A
 end
 
+## avoid the fast path on gpu until we have better cuda support
+function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e::AbstractVector)
+    propagate((xi,xj,e) -> e_mul_xj(xi,xj,e), g, +, xi, xj, e)
+end
 
+## W_MUL_XJ
+
+# for weighted convolution
+function propagate(::typeof(w_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e::Nothing)
+    A = adjacency_matrix(g, weighted=true)
+    return xj * A
+end
 
 ## avoid the fast path on gpu until we have better cuda support
-function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
-    propagate((xi,xj,e)->copy_xj(xi,xj,e), g, +, xi, xj, e)
+function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e::Nothing)
+    propagate((xi,xj,e) -> w_mul_xj(xi,xj,e), g, +, xi, xj, e)
 end
 
+
+
+
+
 # function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
 #     A = adjacency_matrix(g, weighted=false)
 #     D = compute_degree(A)
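
What `w_mul_xj` computes on the generic (unfused) path: a weight vector with one entry per edge is reshaped so it broadcasts along the feature dimensions of `xj`; on CPU with `+` aggregation the specialization above fuses this into a single weighted sparse matmul, `xj * adjacency_matrix(g, weighted=true)`. A small sketch with toy values of our choosing:

xj = rand(4, 3)          # 4 features x 3 edges
w  = [10.0, 20.0, 30.0]  # one scalar weight per edge

m = reshape(w, 1, 3) .* xj       # what w_mul_xj(nothing, xj, w) returns (here Nj = 2)
@assert size(m) == (4, 3)
@assert m[:, 2] == 20.0 .* xj[:, 2]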

test/GNNGraphs/transform.jl

Lines changed: 14 additions & 0 deletions
@@ -173,4 +173,18 @@
         @test g2.num_edges < 50
     end
 end
+
+@testset "set_edge_weight" begin
+    g = rand_graph(10, 20, graph_type=GRAPH_T)
+    w = rand(20)
+
+    gw = set_edge_weight(g, w)
+    @test get_edge_weight(gw) == w
+
+    # now from a weighted graph
+    s, t = edge_index(g)
+    g2 = GNNGraph(s, t, rand(20), graph_type=GRAPH_T)
+    gw2 = set_edge_weight(g2, w)
+    @test get_edge_weight(gw2) == w
+end
 end

test/msgpass.jl

Lines changed: 12 additions & 4 deletions
@@ -85,28 +85,36 @@
         @test spmm_copyxj_fused(g) ≈ X * Adj
     end
 
-    @testset "e_mul_xj for weighted conv" begin
+    @testset "e_mul_xj and w_mul_xj for weighted conv" begin
         n = 128
         A = sprand(n, n, 0.1)
         Adj = map(x -> x > 0 ? 1 : 0, A)
         X = rand(10, n)
 
-        g = GNNGraph(A, ndata=X, edata=reshape(A.nzval, 1, :), graph_type=GRAPH_T)
+        g = GNNGraph(A, ndata=X, edata=A.nzval, graph_type=GRAPH_T)
 
         function spmm_unfused(g)
             propagate(
-                (xi, xj, e) -> e .* xj,
+                (xi, xj, e) -> reshape(e, 1, :) .* xj,
                 g, +; xj=g.ndata.x, e=g.edata.e
             )
         end
         function spmm_fused(g)
             propagate(
                 e_mul_xj,
-                g, +; xj=g.ndata.x, e=vec(g.edata.e)
+                g, +; xj=g.ndata.x, e=g.edata.e
+            )
+        end
+
+        function spmm_fused2(g)
+            propagate(
+                w_mul_xj,
+                g, +; xj=g.ndata.x
             )
         end
 
         @test spmm_unfused(g) ≈ X * A
         @test spmm_fused(g) ≈ X * A
+        @test spmm_fused2(g) ≈ X * A
     end
 end
