Skip to content

Commit 3484489

Browse files
add nodes/edges broadcast
1 parent e23036a commit 3484489

File tree

7 files changed

+164
-36
lines changed

7 files changed

+164
-36
lines changed

docs/src/api/utils.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ Pages = ["utils.md"]
1414
## Docs
1515

1616

17-
### Readout Functions
17+
### Graph-wise operations
1818

1919
```@docs
20-
GraphNeuralNetworks.readout_nodes
21-
GraphNeuralNetworks.readout_edges
20+
GraphNeuralNetworks.reduce_nodes
21+
GraphNeuralNetworks.reduce_edges
2222
GraphNeuralNetworks.softmax_nodes
2323
GraphNeuralNetworks.softmax_edges
24+
GraphNeuralNetworks.broadcast_nodes
25+
GraphNeuralNetworks.broadcast_edges
2426
```
2527

2628
### NNlib

src/GraphNeuralNetworks.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using MacroTools: @forward
1212
import LearnBase
1313
using LearnBase: getobs
1414
using NNlib, NNlibCUDA
15+
using NNlib: scatter, gather
1516
using ChainRulesCore
1617
import LightGraphs
1718
using LightGraphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
@@ -31,8 +32,9 @@ export
3132
sprand, sparse, blockdiag,
3233

3334
# utils
34-
readout_nodes, readout_edges,
35+
reduce_nodes, reduce_edges,
3536
softmax_nodes, softmax_edges,
37+
broadcast_nodes, broadcast_edges,
3638

3739
# msgpass
3840
apply_edges, propagate,

src/layers/pool.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ struct GlobalPool{F} <: GNNLayer
3434
end
3535

3636
function (l::GlobalPool)(g::GNNGraph, x::AbstractArray)
37-
return readout_nodes(g, x, l.aggr)
37+
return reduce_nodes(g, x, l.aggr)
3838
end
3939

4040
"""

src/utils.jl

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -132,65 +132,97 @@ function NNlib.scatter!(op, dst::AnyCuArray, src::Number, idx::AnyCuArray)
132132
end
133133

134134
"""
135-
readout_nodes(g, x, aggr)
135+
reduce_nodes(aggr, g, x)
136136
137137
For a batched graph `g`, return the graph-wise aggregation of the node
138138
features `x`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`.
139139
The returned array will have last dimension `g.num_graphs`.
140140
"""
141-
function readout_nodes(g, x, aggr)
141+
function reduce_nodes(aggr, g::GNNGraph, x)
142+
@assert size(x)[end] == g.num_nodes
142143
indexes = graph_indicator(g)
143144
return NNlib.scatter(aggr, x, indexes)
144145
end
145146

146147
"""
147-
readout_edges(g, e, aggr)
148+
reduce_edges(aggr, g, e)
148149
149150
For a batched graph `g`, return the graph-wise aggregation of the edge
150151
features `e`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`.
151152
The returned array will have last dimension `g.num_graphs`.
152153
"""
153-
function readout_edges(g, e, aggr)
154+
function reduce_edges(aggr, g::GNNGraph, e)
155+
@assert size(e)[end] == g.num_edges
154156
s, t = edge_index(g)
155157
indexes = graph_indicator(g)[s]
156158
return NNlib.scatter(aggr, e, indexes)
157159
end
158160

159161
"""
160-
softmax_nodes(g, x, aggr)
162+
softmax_nodes(g, x)
161163
162164
Graph-wise softmax of the node features `x`.
163165
"""
164-
function softmax_nodes(g, x)
165-
max_ = maximum(x; dims = ndims(x)) # TODO use graph-wise maximum
166-
num = exp.(x .- max_)
167-
den = readout_nodes(g, num, +)
168-
den = Flux.flatten(den) # reshape to matrix for convenience
166+
function softmax_nodes(g::GNNGraph, x)
167+
@assert size(x)[end] == g.num_nodes
169168
gi = graph_indicator(g)
170-
den = den[:, gi]
171-
return num ./ reshape(den, size(num))
169+
max_ = gather(scatter(max, x, gi), gi)
170+
num = exp.(x .- max_)
171+
den = reduce_nodes(+, g, num)
172+
den = gather(den, gi)
173+
return num ./ den
172174
end
173175

174176
"""
175177
softmax_edges(g, e)
176178
177179
Graph-wise softmax of the edge features `e`.
178180
"""
179-
function softmax_edges(g, e)
180-
max_ = maximum(e; dims = ndims(e)) # TODO use graph-wise maximum
181+
function softmax_edges(g::GNNGraph, e)
182+
@assert size(e)[end] == g.num_edges
183+
gi = graph_indicator(g, edges=true)
184+
max_ = gather(scatter(max, e, gi), gi)
181185
num = exp.(e .- max_)
182-
den = readout_edges(g, num, +)
183-
den = Flux.flatten(den) # reshape to matrix for convenience
184-
s, t = edge_index(g)
185-
gi = graph_indicator(g)[s]
186-
den = den[:, gi]
187-
return num ./ reshape(den, size(num))
186+
den = reduce_edges(+, g, num)
187+
den = gather(den, gi)
188+
return num ./ den
189+
end
190+
191+
"""
192+
broadcast_nodes(g, x)
193+
194+
Graph-wise broadcast array `x` of size `(*, g.num_graphs)`
195+
to size `(*, g.num_nodes)`.
196+
"""
197+
function broadcast_nodes(g::GNNGraph, x)
198+
@assert size(x)[end] == g.num_graphs
199+
gi = graph_indicator(g)
200+
return gather(x, gi)
188201
end
189202

190-
function graph_indicator(g)
203+
"""
204+
broadcast_edges(g, x)
205+
206+
Graph-wise broadcast array `x` of size `(*, g.num_graphs)`
207+
to size `(*, g.num_edges)`.
208+
"""
209+
function broadcast_edges(g::GNNGraph, x)
210+
@assert size(x)[end] == g.num_graphs
211+
gi = graph_indicator(g, edges=true)
212+
return gather(x, gi)
213+
end
214+
215+
216+
function graph_indicator(g; edges=false)
191217
if isnothing(g.graph_indicator)
192-
return ones_like(edge_index(g)[1], Int, g.num_nodes)
218+
gi = ones_like(edge_index(g)[1], Int, g.num_nodes)
193219
else
194-
return g.graph_indicator
220+
gi = g.graph_indicator
221+
end
222+
if edges
223+
s, t = edge_index(g)
224+
return gi[s]
225+
else
226+
return gi
195227
end
196-
end
228+
end

test.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
using Flux, Random
2+
3+
4+
function get_grad1(m, data)
5+
gradient(Flux.params(m)) do
6+
loss(m, data)
7+
end
8+
end
9+
10+
function get_grad2(m, data)
11+
gradient(Flux.params(m)) do
12+
ps = Flux.params(m) # just creating params without using them
13+
loss(m, data)
14+
end
15+
end
16+
17+
function get_grad3(m, data)
18+
gradient(Flux.params(m)) do
19+
ps = Flux.params(m)
20+
loss(m, data) + sum(sum(p) for p in ps)
21+
end
22+
end
23+
24+
function get_grad4(m, data)
25+
ps = Flux.params(m)
26+
gradient(Flux.params(m)) do
27+
loss(m, data) + sum(sum(p) for p in ps)
28+
end
29+
end
30+
31+
function get_grad5(m, data)
32+
gradient(Flux.params(m)) do
33+
sum(Flux.params(m)[1]) + sum(Flux.params(m)[2])
34+
end
35+
end
36+
37+
function get_grad6(m, data)
38+
ps = Flux.params(m)
39+
gradient(Flux.params(m)) do
40+
sum(ps[1]) + sum(ps[2])
41+
end
42+
end
43+
44+
function get_grad7(m, data)
45+
ps = Flux.params(m)
46+
gradient(Flux.params(m)) do
47+
sum(m.weight) + sum(m.bias)
48+
end
49+
end
50+
51+
Random.seed!(17)
52+
m = Dense(3, 2);
53+
data = rand(Float32, 3, 5)
54+
loss(m, x) = sum(m(x).^2)
55+
56+
g1 = get_grad1(m, data)
57+
g2 = get_grad2(m, data)
58+
g3 = get_grad3(m, data)
59+
g4 = get_grad4(m, data)
60+
g5 = get_grad5(m, data)
61+
g6 = get_grad6(m, data)
62+
g7 = get_grad7(m, data)
63+
64+
@show g1[m.weight] # correct
65+
@show g2[m.weight] # == g1 .+ 1 wrong, should be == g1
66+
@show g3[m.weight] # == g1 .+ 2 wrong, should be == g1 .+ 1
67+
@show g4[m.weight] # == g1 .+ 1 correct
68+
@show g5[m.weight] # .== 2 wrong, should be .== 1
69+
@show g6[m.weight] # == nothing wrong, should be .== 1
70+
@show g7[m.weight] # == 1 correct

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ include("test_utils.jl")
1919

2020
tests = [
2121
"gnngraph",
22-
"utils"
22+
"utils",
2323
"msgpass",
2424
"layers/basic",
2525
"layers/conv",

test/utils.jl

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
@testset "Utils" begin
22
De, Dx = 3, 2
3-
g = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(Dx, 10), edata=rand(De, 30)) for i=1:5])
3+
g = Flux.batch([GNNGraph(erdos_renyi(10, 30),
4+
ndata=rand(Dx, 10),
5+
edata=rand(De, 30),
6+
graph_type=GRAPH_T) for i=1:5])
47
x = g.ndata.x
58
e = g.edata.e
69

7-
@testset "readout_nodes" begin
8-
r = readout_nodes(g, x, mean)
10+
@testset "reduce_nodes" begin
11+
r = reduce_nodes(mean, g, x)
912
@test size(r) == (Dx, g.num_graphs)
1013
@test r[:,2] ≈ mean(getgraph(g, 2).ndata.x, dims=2)
1114
end
1215

13-
@testset "readout_edges" begin
14-
r = readout_edges(g, e, mean)
16+
@testset "reduce_edges" begin
17+
r = reduce_edges(mean, g, e)
1518
@test size(r) == (De, g.num_graphs)
1619
@test r[:,2] ≈ mean(getgraph(g, 2).edata.e, dims=2)
1720
end
@@ -27,4 +30,23 @@
2730
@test size(r) == size(e)
2831
@test r[:,1:60] ≈ softmax(getgraph(g, 1).edata.e, dims=2)
2932
end
30-
end
33+
34+
35+
@testset "broadcast_nodes" begin
36+
z = rand(4, g.num_graphs)
37+
r = broadcast_nodes(g, z)
38+
@test size(r) == (4, g.num_nodes)
39+
@test r[:,1] ≈ z[:,1]
40+
@test r[:,10] ≈ z[:,1]
41+
@test r[:,11] ≈ z[:,2]
42+
end
43+
44+
@testset "broadcast_edges" begin
45+
z = rand(4, g.num_graphs)
46+
r = broadcast_edges(g, z)
47+
@test size(r) == (4, g.num_edges)
48+
@test r[:,1] ≈ z[:,1]
49+
@test r[:,60] ≈ z[:,1]
50+
@test r[:,61] ≈ z[:,2]
51+
end
52+
end

0 commit comments

Comments
 (0)