Skip to content

Commit d5e55eb

Browse files
generalize propagation to tuple and namedtuple
1 parent 535d866 commit d5e55eb

File tree

4 files changed

+71
-29
lines changed

4 files changed

+71
-29
lines changed

src/gnngraph.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ from the LightGraphs' graph library can be used on it.
5959
6060
# Usage.
6161
62-
```
62+
```julia
6363
using Flux, GraphNeuralNetworks
6464
6565
# Construct from adjacency list representation

src/layers/conv.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ end
196196

197197

198198
@doc raw"""
199-
GATConv(in => out;
199+
GATConv(in => out, σ=identity;
200200
heads=1,
201201
concat=true,
202202
init=glorot_uniform
@@ -228,6 +228,7 @@ struct GATConv{T, A<:AbstractMatrix{T}, B} <: GNNLayer
228228
weight::A
229229
bias::B
230230
a::A
231+
σ
231232
negative_slope::T
232233
channel::Pair{Int, Int}
233234
heads::Int
@@ -237,44 +238,43 @@ end
237238
@functor GATConv
238239
Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)
239240

240-
function GATConv(ch::Pair{Int,Int};
241+
function GATConv(ch::Pair{Int,Int}, σ=identity;
241242
heads::Int=1, concat::Bool=true, negative_slope=0.2f0,
242243
init=glorot_uniform, bias::Bool=true)
243244
in, out = ch
244245
W = init(out*heads, in)
245246
b = Flux.create_bias(W, bias, out*heads)
246247
a = init(2*out, heads)
247-
GATConv(W, b, a, negative_slope, ch, heads, concat)
248+
GATConv(W, b, a, σ, negative_slope, ch, heads, concat)
248249
end
249250

250-
function (gat::GATConv)(g::GNNGraph, X::AbstractMatrix)
251-
check_num_nodes(g, X)
251+
function compute_message(l::GATConv, Wxi, Wxj)
252+
aWW = sum(l.a .* cat(Wxi, Wxj, dims=1), dims=1) # 1 × nheads × nedges
253+
α = exp.(leakyrelu.(aWW, l.negative_slope))
254+
return (α = α, m = α .* Wxj)
255+
end
256+
257+
update_node(l::GATConv, d̄, x) = d̄.m ./ d̄.α
258+
259+
function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
260+
check_num_nodes(g, x)
252261
g = add_self_loops(g)
253-
chin, chout = gat.channel
254-
heads = gat.heads
262+
chin, chout = l.channel
263+
heads = l.heads
255264

256-
source, target = edge_index(g)
257-
Wx = gat.weight*X
265+
Wx = l.weight * x
258266
Wx = reshape(Wx, chout, heads, :) # chout × nheads × nnodes
259-
Wxi = NNlib.gather(Wx, target) # chout × nheads × nedges
260-
Wxj = NNlib.gather(Wx, source)
261-
262-
# Edge Message
263-
# Computing softmax. TODO make it numerically stable
264-
aWW = sum(gat.a .* cat(Wxi, Wxj, dims=1), dims=1) # 1 × nheads × nedges
265-
α = exp.(leakyrelu.(aWW, gat.negative_slope))
266-
X̄ = NNlib.scatter(+, α .* Wxj, target) # chout × nheads × nnodes
267-
ᾱ = NNlib.scatter(+, α, target) # 1 × nheads × nnodes
268267

269-
# Node update
270-
b = reshape(gat.bias, chout, heads)
271-
X = X̄ ./ ᾱ .+ b # chout × nheads × nnodes
272-
if !gat.concat
273-
X = sum(X, dims=2)
268+
x, _ = propagate(l, g, +, Wx) ## chout × nheads × nnodes
269+
270+
b = reshape(l.bias, chout, heads)
271+
x = l.σ.(x .+ b)
272+
if !l.concat
273+
x = sum(x, dims=2)
274274
end
275275

276276
# We finally return a matrix
277-
return reshape(X, :, size(X, 3))
277+
return reshape(x, :, size(x, 3))
278278
end
279279

280280

src/msgpass.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ this method in the forward pass:
2424
2525
# Usage example
2626
27-
```
27+
```julia
2828
using GraphNeuralNetworks, Flux
2929
3030
struct GNNConv <: GNNLayer
@@ -101,7 +101,9 @@ function compute_message end
101101
@inline compute_message(l, x_i, x_j, e_ij) = compute_message(l, x_i, x_j)
102102
@inline compute_message(l, x_i, x_j) = x_j
103103

104-
_gather(x, i) = NNlib.gather(x, i)
104+
_gather(x::NamedTuple, i) = map(x -> _gather(x, i), x)
105+
_gather(x::Tuple, i) = map(x -> _gather(x, i), x)
106+
_gather(x::AbstractArray, i) = NNlib.gather(x, i)
105107
_gather(x::Nothing, i) = nothing
106108

107109
function compute_batch_message(l, g, x, e)
@@ -114,9 +116,14 @@ end
114116

115117
## Step 2
116118

119+
_scatter(aggr, e::NamedTuple, t) = map(e -> _scatter(aggr, e, t), e)
120+
_scatter(aggr, e::Tuple, t) = map(e -> _scatter(aggr, e, t), e)
121+
_scatter(aggr, e::AbstractArray, t) = NNlib.scatter(aggr, e, t)
122+
_scatter(aggr, e::Nothing, t) = nothing
123+
117124
function aggregate_neighbors(l, g, aggr, e)
118125
s, t = edge_index(g)
119-
NNlib.scatter(aggr, e, t)
126+
_scatter(aggr, e, t)
120127
end
121128

122129
aggregate_neighbors(l, g, aggr::Nothing, e) = nothing

test/msgpass.jl

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import GraphNeuralNetworks: compute_message, update_node, update_edge, propagate
2+
13
@testset "message passing" begin
24
in_channel = 10
35
out_channel = 5
@@ -113,9 +115,9 @@
113115
GraphNeuralNetworks.compute_message(l::NewLayerW{GRAPH_T}, x_i, x_j, e_ij) = l.weight * x_j
114116
GraphNeuralNetworks.update_node(l::NewLayerW{GRAPH_T}, m, x) = l.weight * x + m
115117

116-
l = NewLayerW(in_channel, out_channel)
117118
(l::NewLayerW{GRAPH_T})(g) = GraphNeuralNetworks.propagate(l, g, +)
118119

120+
l = NewLayerW(in_channel, out_channel)
119121
g = GNNGraph(adj, ndata=X, edata=E, gdata=U, graph_type=GRAPH_T)
120122
g_ = l(g)
121123

@@ -124,4 +126,37 @@
124126
@test edge_features(g_) === E
125127
@test graph_features(g_) === U
126128
end
129+
130+
@testset "NamedTuples" begin
131+
struct NewLayerNT{G}
132+
W
133+
end
134+
135+
NewLayerNT(in, out) = NewLayerW{GRAPH_T}(randn(T, out, in))
136+
137+
function compute_message(l::NewLayerW{GRAPH_T}, di, dj, dij)
138+
a = l.W * (di.x .+ dj.x) + dij.e
139+
b = l.W * di.x
140+
return (; a, b)
141+
end
142+
function update_node(l::NewLayerW{GRAPH_T}, m, x)
143+
return (α = l.W * x + m.a + m.b, β = m)
144+
end
145+
function update_edge(l::NewLayerW{GRAPH_T}, m, e)
146+
return m.a
147+
end
148+
149+
function (::NewLayerNT)(l, g, x, e)
150+
x, e = propagate(l, g, (; x), (; e))
151+
return x.α .+ x.β.a, e
152+
end
153+
154+
155+
l = NewLayerNT(in_channel, out_channel)
156+
g = GNNGraph(adj, graph_type=GRAPH_T)
157+
X′, E′ = l(g, X, E)
158+
159+
@test size(X′) == (out_channel, num_V)
160+
@test size(E′) == (out_channel, num_E)
161+
end
127162
end

0 commit comments

Comments
 (0)