
Commit 24512ee

add softmax_edge_neighbors (#59)
* add softmax_edge_neighbors
1 parent 5ad064f commit 24512ee

File tree

6 files changed: +51 additions, −14 deletions

docs/src/api/utils.md

Lines changed: 6 additions & 0 deletions
@@ -25,6 +25,12 @@ GraphNeuralNetworks.broadcast_nodes
 GraphNeuralNetworks.broadcast_edges
 ```
 
+### Neighborhood operations
+
+```@docs
+GraphNeuralNetworks.softmax_edge_neighbors
+```
+
 ### NNlib
 
 Primitive functions implemented in NNlib.jl.

docs/src/messagepassing.md

Lines changed: 5 additions & 5 deletions
@@ -14,14 +14,14 @@ A generic message passing on graph takes the form
 where we refer to ``\phi`` as to the message function,
 and to ``\gamma_x`` and ``\gamma_e`` as to the node update and edge update function
 respectively. The aggregation ``\square`` is over the neighborhood ``N(i)`` of node ``i``,
-and it is usually set to summation ``\sum``, a max or a mean operation.
+and it is usually equal either to ``\sum``, to `max` or to a `mean` operation.
 
 In GNN.jl, the function [`propagate`](@ref) takes care of materializing the
 node features on each edge, applying the message function, performing the
 aggregation, and returning ``\bar{\mathbf{m}}``.
 It is then left to the user to perform further node and edge updates,
-manypulating arrays of size ``D_{node} \times num_nodes`` and
-``D_{edge} \times num_edges``.
+manypulating arrays of size ``D_{node} \times num\_nodes`` and
+``D_{edge} \times num\_edges``.
 
 As part of the [`propagate`](@ref) pipeline, we have the function
 [`apply_edges`](@ref). It can be independently used to materialize

@@ -34,9 +34,9 @@ and [`NNlib.scatter`](@ref) methods.
 
 ## Examples
 
-### Basic use propagate and apply_edges
-
+### Basic use of propagate and apply_edges
 
+TODO
 
 ### Implementing a custom Graph Convolutional Layer
 
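The materialize/message/aggregate steps this documentation change attributes to `propagate` can be sketched for the ``\sum`` aggregator in plain Julia. `toy_propagate`, `s`, `t`, and `message` are illustrative stand-ins for the library's pipeline, not its API:

```julia
# Materialize source-node features on each edge, apply a message function,
# then sum-aggregate the messages into each edge's target node.
function toy_propagate(message, s::Vector{Int}, t::Vector{Int},
                       x::Matrix{Float64}, num_nodes::Int)
    mbar = zeros(size(x, 1), num_nodes)   # D_node × num_nodes accumulator
    for (j, i) in zip(s, t)               # edge j → i
        mbar[:, i] .+= message(x[:, j])   # aggregation over N(i) is a sum
    end
    return mbar
end

s = [1, 2, 3]; t = [3, 3, 1]              # edges 1→3, 2→3, 3→1
x = Float64[1 2 3; 10 20 30]              # D_node = 2, three nodes
mbar = toy_propagate(identity, s, t, x, 3)
# mbar[:, 3] == x[:, 1] + x[:, 2] == [3.0, 30.0]
```

The returned array has size ``D_{node} \times num\_nodes``, matching the shape convention described in the prose above.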

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ export
 reduce_nodes, reduce_edges,
 softmax_nodes, softmax_edges,
 broadcast_nodes, broadcast_edges,
+softmax_edge_neighbors,
 
 # msgpass
 apply_edges, propagate,

src/layers/pool.jl

Lines changed: 9 additions & 8 deletions
@@ -52,24 +52,24 @@ Global soft attention layer from the [Gated Graph Sequence Neural
 Networks](https://arxiv.org/abs/1511.05493) paper
 
 ```math
-\mathbf{u}_V} = \sum_{i\in V} \alpha_i\, f_{\mathrm{feat}}(\mathbf{x}_i)
+\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i)
 ```
 
-where the coefficients ``alpha_i`` are given by a [`softmax_nodes`](@ref)
+where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref)
 operation:
 
 ```math
-\alpha_i = \frac{e^{f_{\mathrm{feat}}(\mathbf{x}_i)}}
-           {\sum_{i'\in V} e^{f_{\mathrm{feat}}(\mathbf{x}_{i'})}}.
+\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}}
+           {\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}.
 ```
 
 # Arguments
 
-- `fgate`: The function ``f_{\mathrm{gate}} \colon \mathbb{R}^{D_{in}} \to
-  \mathbb{R}``. It is tipically a neural network.
+- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
+  It is tipically expressed by a neural network.
 
-- `ffeat`: The function ``f_{\mathrm{feat}} \colon \mathbb{R}^{D_{in}} \to
-  \mathbb{R}^{D_{out}}``. It is tipically a neural network.
+- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
+  It is tipically expressed by a neural network.
 
 # Examples
 

@@ -88,6 +88,7 @@ g = Flux.batch([GNNGraph(random_regular_graph(10, 4),
 u = pool(g, g.ndata.x)
 
 @assert size(u) == (chout, g.num_graphs)
+```
 """
 struct GlobalAttentionPool{G,F}
     fgate::G
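The pooling equation corrected in the docstring above can be written out directly. This is a plain-Julia sketch for a single graph, where `fgate` and `ffeat` are stand-in closures rather than the trained networks the `GlobalAttentionPool` layer expects, and `attention_pool` is an illustrative name, not the library's implementation:

```julia
# Global soft attention pooling: softmax the gate scores over all nodes,
# then take the weighted sum of the transformed node features.
function attention_pool(fgate, ffeat, X::Matrix{Float64})
    scores = [fgate(X[:, i]) for i in 1:size(X, 2)]  # one scalar per node
    α = exp.(scores .- maximum(scores))              # stable softmax
    α ./= sum(α)
    return sum(α[i] .* ffeat(X[:, i]) for i in 1:size(X, 2))
end

X = randn(4, 10)                     # D_in = 4 features, 10 nodes
fgate = x -> sum(x)                  # stand-in for a neural network gate
ffeat = x -> x .^ 2                  # stand-in feature transform
u = attention_pool(fgate, ffeat, X)  # pooled vector of length 4
```

Because the ``\alpha_i`` sum to one, pooling constant node features returns them unchanged, which is a quick sanity check for any implementation.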

src/utils.jl

Lines changed: 19 additions & 0 deletions
@@ -190,6 +190,25 @@ function softmax_edges(g::GNNGraph, e)
     return num ./ den
 end
 
+@doc raw"""
+    softmax_edge_neighbors(g, e)
+
+Softmax over each node's neighborhood of the edge features `e`.
+
+```math
+\mathbf{e}'_{j\to i} = \frac{e^{\mathbf{e}_{j\to i}}}
+                            {\sum_{j'\in N(i)} e^{\mathbf{e}_{j'\to i}}}.
+```
+"""
+function softmax_edge_neighbors(g::GNNGraph, e)
+    @assert size(e)[end] == g.num_edges
+    s, t = edge_index(g)
+    max_ = gather(scatter(max, e, t), t)
+    num = exp.(e .- max_)
+    den = gather(scatter(+, num, t), t)
+    return num ./ den
+end
+
 """
     broadcast_nodes(g, x)
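The gather/scatter pipeline in `softmax_edge_neighbors` above can be sketched in plain Julia, with a loop over target nodes standing in for `NNlib.scatter`/`gather` and `t` for the target half of `edge_index` (`softmax_over_targets` is an illustrative name, not the library's API):

```julia
# Numerically stable softmax of edge features over each target node's
# incoming edges: subtract the per-neighborhood max before exponentiating,
# mirroring the max_ trick in the function added by this commit.
function softmax_over_targets(t::Vector{Int}, e::Matrix{Float64})
    z = similar(e)
    for i in unique(t)
        cols = findall(==(i), t)          # edges pointing into node i
        m = maximum(e[:, cols], dims=2)
        ex = exp.(e[:, cols] .- m)
        z[:, cols] = ex ./ sum(ex, dims=2)
    end
    return z
end

t = [5, 5, 6, 6]              # target node of each edge, as in the new test
e = [1.0 2.0 3.0 4.0]         # one feature row, four edges
z = softmax_over_targets(t, e)
# Each neighborhood's weights sum to 1: sum(z[:, 1:2]) == sum(z[:, 3:4]) == 1
```

Edges 1 and 2 share target node 5 and edges 3 and 4 share target node 6, so the softmax is taken independently over columns 1:2 and 3:4, which is exactly what the added testset checks against `NNlib.softmax`.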

test/utils.jl

Lines changed: 11 additions & 1 deletion
@@ -31,7 +31,6 @@
     @test r[:,1:60] ≈ softmax(getgraph(g, 1).edata.e, dims=2)
 end
 
-
 @testset "broadcast_nodes" begin
     z = rand(4, g.num_graphs)
     r = broadcast_nodes(g, z)

@@ -49,4 +48,15 @@
     @test r[:,60] ≈ z[:,1]
     @test r[:,61] ≈ z[:,2]
 end
+
+@testset "softmax_edge_neighbors" begin
+    s = [1,2,3,4]
+    t = [5,5,6,6]
+    g2 = GNNGraph(s, t)
+    e2 = randn(Float32, 3, g2.num_edges)
+    z = softmax_edge_neighbors(g2, e2)
+    @test size(z) == size(e2)
+    @test z[:,1:2] ≈ NNlib.softmax(e2[:,1:2], dims=2)
+    @test z[:,3:4] ≈ NNlib.softmax(e2[:,3:4], dims=2)
+end
 end
