Skip to content

Commit 867cc10

Browse files
Merge pull request #48 from CarloLucibello/cl/redesign
add nodes/edges softmax and readout
2 parents b2f6ebf + d4db332 commit 867cc10

File tree

9 files changed

+207
-42
lines changed

9 files changed

+207
-42
lines changed

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ makedocs(;
1818
"Convolutional Layers" => "api/conv.md",
1919
"Pooling Layers" => "api/pool.md",
2020
"Message Passing" => "api/messagepassing.md",
21-
"NNlib" => "api/nnlib.md",
21+
"Utils" => "api/utils.md",
2222
],
2323
"Developer Notes" => "dev.md",
2424
],

docs/src/api/nnlib.md

Lines changed: 0 additions & 23 deletions
This file was deleted.

docs/src/api/utils.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
```@meta
2+
CurrentModule = GraphNeuralNetworks
3+
```
4+
5+
# Utility Functions
6+
7+
## Index
8+
9+
```@index
10+
Order = [:type, :function]
11+
Pages = ["utils.md"]
12+
```
13+
14+
## Docs
15+
16+
17+
### Graph-wise operations
18+
19+
```@docs
20+
GraphNeuralNetworks.reduce_nodes
21+
GraphNeuralNetworks.reduce_edges
22+
GraphNeuralNetworks.softmax_nodes
23+
GraphNeuralNetworks.softmax_edges
24+
GraphNeuralNetworks.broadcast_nodes
25+
GraphNeuralNetworks.broadcast_edges
26+
```
27+
28+
### NNlib
29+
30+
Primitive functions implemented in NNlib.jl.
31+
32+
```@docs
33+
NNlib.gather!
34+
NNlib.gather
35+
NNlib.scatter!
36+
NNlib.scatter
37+
```

src/GraphNeuralNetworks.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using MacroTools: @forward
1212
import LearnBase
1313
using LearnBase: getobs
1414
using NNlib, NNlibCUDA
15+
using NNlib: scatter, gather
1516
using ChainRulesCore
1617
import LightGraphs
1718
using LightGraphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
@@ -30,6 +31,11 @@ export
3031
# from SparseArrays
3132
sprand, sparse, blockdiag,
3233

34+
# utils
35+
reduce_nodes, reduce_edges,
36+
softmax_nodes, softmax_edges,
37+
broadcast_nodes, broadcast_edges,
38+
3339
# msgpass
3440
apply_edges, propagate,
3541
copyxj,

src/gnngraph.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -541,25 +541,24 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
541541
graphmap = Dict(i => inew for (inew, i) in enumerate(i))
542542
graph_indicator = [graphmap[i] for i in g.graph_indicator[node_mask]]
543543

544+
s, t = edge_index(g)
545+
w = edge_weight(g)
546+
edge_mask = s .∈ Ref(nodes)
547+
544548
if g.graph isa COO_T
545-
s, t = edge_index(g)
546-
w = edge_weight(g)
547-
edge_mask = s .∈ Ref(nodes)
548549
s = [nodemap[i] for i in s[edge_mask]]
549550
t = [nodemap[i] for i in t[edge_mask]]
550551
w = isnothing(w) ? nothing : w[edge_mask]
551552
graph = (s, t, w)
552-
num_edges = length(s)
553-
edata = getobs(g.edata, edge_mask)
554553
elseif g.graph isa ADJMAT_T
555554
graph = g.graph[nodes, nodes]
556-
num_edges = count(>=(0), graph)
557-
@assert g.edata == (;) # TODO
558-
edata = (;)
559555
end
556+
560557
ndata = getobs(g.ndata, node_mask)
558+
edata = getobs(g.edata, edge_mask)
561559
gdata = getobs(g.gdata, i)
562-
560+
561+
num_edges = sum(edge_mask)
563562
num_nodes = length(graph_indicator)
564563
num_graphs = length(i)
565564

src/layers/pool.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ where ``V`` is the set of nodes of the input graph and
1414
the type of aggregation represented by ``\square`` is selected by the `aggr` argument.
1515
Commonly used aggregations are `mean`, `max`, and `+`.
1616
17+
See also [`reduce_nodes`](@ref).
18+
19+
# Examples
1720
```julia
1821
using Flux, GraphNeuralNetworks, LightGraphs
1922
@@ -33,14 +36,8 @@ struct GlobalPool{F} <: GNNLayer
3336
aggr::F
3437
end
3538

36-
function (l::GlobalPool)(g::GNNGraph, X::AbstractArray)
37-
if isnothing(g.graph_indicator)
38-
# assume only one graph
39-
indexes = fill!(similar(X, Int, g.num_nodes), 1)
40-
else
41-
indexes = g.graph_indicator
42-
end
43-
return NNlib.scatter(l.aggr, X, indexes)
39+
function (l::GlobalPool)(g::GNNGraph, x::AbstractArray)
40+
return reduce_nodes(l.aggr, g, x)
4441
end
4542

4643
"""

src/utils.jl

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,4 +129,100 @@ function NNlib.scatter!(op, dst::AnyCuArray, src::Number, idx::AnyCuArray)
129129
blocks = cld(max_idx, threads)
130130
kernel(args...; threads=threads, blocks=blocks)
131131
return dst
132-
end
132+
end
133+
134+
"""
135+
reduce_nodes(aggr, g, x)
136+
137+
For a batched graph `g`, return the graph-wise aggregation of the node
138+
features `x`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`.
139+
The returned array will have last dimension `g.num_graphs`.
140+
"""
141+
function reduce_nodes(aggr, g::GNNGraph, x)
142+
@assert size(x)[end] == g.num_nodes
143+
indexes = graph_indicator(g)
144+
return NNlib.scatter(aggr, x, indexes)
145+
end
146+
147+
"""
148+
reduce_edges(aggr, g, e)
149+
150+
For a batched graph `g`, return the graph-wise aggregation of the edge
151+
features `e`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`.
152+
The returned array will have last dimension `g.num_graphs`.
153+
"""
154+
function reduce_edges(aggr, g::GNNGraph, e)
155+
@assert size(e)[end] == g.num_edges
156+
s, t = edge_index(g)
157+
indexes = graph_indicator(g)[s]
158+
return NNlib.scatter(aggr, e, indexes)
159+
end
160+
161+
"""
162+
softmax_nodes(g, x)
163+
164+
Graph-wise softmax of the node features `x`.
165+
"""
166+
function softmax_nodes(g::GNNGraph, x)
167+
@assert size(x)[end] == g.num_nodes
168+
gi = graph_indicator(g)
169+
max_ = gather(scatter(max, x, gi), gi)
170+
num = exp.(x .- max_)
171+
den = reduce_nodes(+, g, num)
172+
den = gather(den, gi)
173+
return num ./ den
174+
end
175+
176+
"""
177+
softmax_edges(g, e)
178+
179+
Graph-wise softmax of the edge features `e`.
180+
"""
181+
function softmax_edges(g::GNNGraph, e)
182+
@assert size(e)[end] == g.num_edges
183+
gi = graph_indicator(g, edges=true)
184+
max_ = gather(scatter(max, e, gi), gi)
185+
num = exp.(e .- max_)
186+
den = reduce_edges(+, g, num)
187+
den = gather(den, gi)
188+
return num ./ den
189+
end
190+
191+
"""
192+
broadcast_nodes(g, x)
193+
194+
Graph-wise broadcast array `x` of size `(*, g.num_graphs)`
195+
to size `(*, g.num_nodes)`.
196+
"""
197+
function broadcast_nodes(g::GNNGraph, x)
198+
@assert size(x)[end] == g.num_graphs
199+
gi = graph_indicator(g)
200+
return gather(x, gi)
201+
end
202+
203+
"""
204+
broadcast_edges(g, x)
205+
206+
Graph-wise broadcast array `x` of size `(*, g.num_graphs)`
207+
to size `(*, g.num_edges)`.
208+
"""
209+
function broadcast_edges(g::GNNGraph, x)
210+
@assert size(x)[end] == g.num_graphs
211+
gi = graph_indicator(g, edges=true)
212+
return gather(x, gi)
213+
end
214+
215+
216+
function graph_indicator(g; edges=false)
217+
if isnothing(g.graph_indicator)
218+
gi = ones_like(edge_index(g)[1], Int, g.num_nodes)
219+
else
220+
gi = g.graph_indicator
221+
end
222+
if edges
223+
s, t = edge_index(g)
224+
return gi[s]
225+
else
226+
return gi
227+
end
228+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ include("test_utils.jl")
1919

2020
tests = [
2121
"gnngraph",
22+
"utils",
2223
"msgpass",
2324
"layers/basic",
2425
"layers/conv",

test/utils.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
@testset "Utils" begin
2+
De, Dx = 3, 2
3+
g = Flux.batch([GNNGraph(erdos_renyi(10, 30),
4+
ndata=rand(Dx, 10),
5+
edata=rand(De, 30),
6+
graph_type=GRAPH_T) for i=1:5])
7+
x = g.ndata.x
8+
e = g.edata.e
9+
10+
@testset "reduce_nodes" begin
11+
r = reduce_nodes(mean, g, x)
12+
@test size(r) == (Dx, g.num_graphs)
13+
@test r[:,2] mean(getgraph(g, 2).ndata.x, dims=2)
14+
end
15+
16+
@testset "reduce_edges" begin
17+
r = reduce_edges(mean, g, e)
18+
@test size(r) == (De, g.num_graphs)
19+
@test r[:,2] mean(getgraph(g, 2).edata.e, dims=2)
20+
end
21+
22+
@testset "softmax_nodes" begin
23+
r = softmax_nodes(g, x)
24+
@test size(r) == size(x)
25+
@test r[:,1:10] softmax(getgraph(g, 1).ndata.x, dims=2)
26+
end
27+
28+
@testset "softmax_edges" begin
29+
r = softmax_edges(g, e)
30+
@test size(r) == size(e)
31+
@test r[:,1:60] softmax(getgraph(g, 1).edata.e, dims=2)
32+
end
33+
34+
35+
@testset "broadcast_nodes" begin
36+
z = rand(4, g.num_graphs)
37+
r = broadcast_nodes(g, z)
38+
@test size(r) == (4, g.num_nodes)
39+
@test r[:,1] z[:,1]
40+
@test r[:,10] z[:,1]
41+
@test r[:,11] z[:,2]
42+
end
43+
44+
@testset "broadcast_edges" begin
45+
z = rand(4, g.num_graphs)
46+
r = broadcast_edges(g, z)
47+
@test size(r) == (4, g.num_edges)
48+
@test r[:,1] z[:,1]
49+
@test r[:,60] z[:,1]
50+
@test r[:,61] z[:,2]
51+
end
52+
end

0 commit comments

Comments
 (0)