Skip to content

Commit dc792d5

Browse files
add softmax and readout
1 parent b2f6ebf commit dc792d5

File tree

7 files changed

+135
-24
lines changed

7 files changed

+135
-24
lines changed

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ makedocs(;
1818
"Convolutional Layers" => "api/conv.md",
1919
"Pooling Layers" => "api/pool.md",
2020
"Message Passing" => "api/messagepassing.md",
21-
"NNlib" => "api/nnlib.md",
21+
"Utils" => "api/utils.md",
2222
],
2323
"Developer Notes" => "dev.md",
2424
],

docs/src/api/nnlib.md

Lines changed: 0 additions & 23 deletions
This file was deleted.

docs/src/api/utils.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
```@meta
2+
CurrentModule = GraphNeuralNetworks
3+
```
4+
5+
# Utility Functions
6+
7+
## Index
8+
9+
```@index
10+
Order = [:type, :function]
11+
Pages = ["utils.md"]
12+
```
13+
14+
## Docs
15+
16+
17+
### Readout Functions
18+
19+
```@docs
20+
GraphNeuralNetworks.readout_nodes
21+
GraphNeuralNetworks.readout_edges
22+
GraphNeuralNetworks.softmax_nodes
23+
GraphNeuralNetworks.softmax_edges
24+
```
25+
26+
### NNlib
27+
28+
Primitive functions implemented in NNlib.jl.
29+
30+
```@docs
31+
NNlib.gather!
32+
NNlib.gather
33+
NNlib.scatter!
34+
NNlib.scatter
35+
```

src/GraphNeuralNetworks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ export
3030
# from SparseArrays
3131
sprand, sparse, blockdiag,
3232

33+
# utils
34+
readout_nodes, readout_edges,
35+
softmax_nodes, softmax_edges,
36+
3337
# msgpass
3438
apply_edges, propagate,
3539
copyxj,

src/utils.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,4 +129,68 @@ function NNlib.scatter!(op, dst::AnyCuArray, src::Number, idx::AnyCuArray)
129129
blocks = cld(max_idx, threads)
130130
kernel(args...; threads=threads, blocks=blocks)
131131
return dst
132+
end
133+
134+
"""
135+
readout_nodes(g, x, aggr)
136+
137+
For a batched graph `g`, return the graph-wise aggregation of the node
138+
features `x`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`.
139+
The returned array will have last dimension `g.num_graphs`.
140+
"""
141+
function readout_nodes(g, x, aggr)
142+
indexes = graph_indicator(g)
143+
return NNlib.scatter(aggr, x, indexes)
144+
end
145+
146+
"""
147+
readout_edges(g, e, aggr)
148+
149+
For a batched graph `g`, return the graph-wise aggregation of the edge
150+
features `e`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`.
151+
The returned array will have last dimension `g.num_graphs`.
152+
"""
153+
function readout_edges(g, e, aggr)
154+
s, t = edge_index(g)
155+
indexes = graph_indicator(g)[s]
156+
return NNlib.scatter(aggr, e, indexes)
157+
end
158+
159+
"""
160+
softmax_nodes(g, x, aggr)
161+
162+
Graph-wise softmax of the node features `x`.
163+
"""
164+
function softmax_nodes(g, x)
165+
max_ = maximum(x; dims = ndims(x)) # TODO use graph-wise maximum
166+
num = exp.(x .- max_)
167+
den = readout_nodes(g, num, +)
168+
den = Flux.flatten(den) # reshape to matrix for convenience
169+
gi = graph_indicator(g)
170+
den = den[:, gi]
171+
return num ./ reshape(den, size(num))
172+
end
173+
174+
"""
175+
softmax_edges(g, e)
176+
177+
Graph-wise softmax of the edge features `e`.
178+
"""
179+
function softmax_edges(g, e)
180+
max_ = maximum(e; dims = ndims(e)) # TODO use graph-wise maximum
181+
num = exp.(e .- max_)
182+
den = readout_edges(g, num, +)
183+
den = Flux.flatten(den) # reshape to matrix for convenience
184+
s, t = edge_index(g)
185+
gi = graph_indicator(g)[s]
186+
den = den[:, gi]
187+
return num ./ reshape(den, size(num))
188+
end
189+
190+
function graph_indicator(g)
191+
if isnothing(g.graph_indicator)
192+
return ones_like(edge_index(g)[1], Int, g.num_nodes)
193+
else
194+
return g.graph_indicator
195+
end
132196
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ include("test_utils.jl")
1919

2020
tests = [
2121
"gnngraph",
22+
"utils"
2223
"msgpass",
2324
"layers/basic",
2425
"layers/conv",

test/utils.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
@testset "Utils" begin
2+
De, Dx = 3, 2
3+
g = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(Dx, 10), edata=rand(De, 30)) for i=1:5])
4+
x = g.ndata.x
5+
e = g.edata.e
6+
7+
@testset "readout_nodes" begin
8+
r = readout_nodes(g, x, mean)
9+
@test size(r) == (Dx, g.num_graphs)
10+
@test r[:,2] mean(getgraph(g, 2).ndata.x, dims=2)
11+
end
12+
13+
@testset "readout_edges" begin
14+
r = readout_edges(g, e, mean)
15+
@test size(r) == (De, g.num_graphs)
16+
@test r[:,2] mean(getgraph(g, 2).edata.e, dims=2)
17+
end
18+
19+
@testset "softmax_nodes" begin
20+
r = softmax_nodes(g, x)
21+
@test size(r) == size(x)
22+
@test r[:,1:10] softmax(getgraph(g, 1).ndata.x, dims=2)
23+
end
24+
25+
@testset "softmax_edges" begin
26+
r = softmax_edges(g, e)
27+
@test size(r) == size(e)
28+
@test r[:,1:60] softmax(getgraph(g, 1).edata.e, dims=2)
29+
end
30+
end

0 commit comments

Comments
 (0)