
Commit d8a6f5a

GATv2 fix with edge features + doc improvements (#131)

* docs
* cleanup

Parent: 00909d4

File tree

4 files changed: +44 −24 lines

4 files changed

+44
-24
lines changed

docs/src/api/conv.md

Lines changed: 2 additions & 3 deletions

@@ -4,9 +4,8 @@ CurrentModule = GraphNeuralNetworks
 
 # Convolutional Layers
 
-Many different types of graphs convolutional layers have been proposed in the literature.
-Choosing the right layer for your application can be a matter of trial and error.
-Some of the most commonly used layers are the [`GCNConv`](@ref) and the [`GATv2Conv`](@ref) layers. Multiple graph convolutional layers are stacked to create a graph neural network model
+Many different types of graph convolutional layers have been proposed in the literature. Choosing the right layer for your application could involve a lot of exploration.
+Some of the most commonly used layers are the [`GCNConv`](@ref) and the [`GATv2Conv`](@ref). Multiple graph convolutional layers are typically stacked together to create a graph neural network model
 (see [`GNNChain`](@ref)).
 
 The table below lists all graph convolutional layers implemented in *GraphNeuralNetworks.jl*. It also highlights the presence of some additional capabilities with respect to basic message passing:
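
The new wording points at stacking layers via [`GNNChain`](@ref). A minimal sketch of what that looks like (the toy graph, feature sizes, and layer widths here are illustrative, not part of the commit):

```julia
using GraphNeuralNetworks, Flux

# Toy graph: 10 nodes, 30 random edges, 3 input features per node.
g = rand_graph(10, 30, ndata = rand(Float32, 3, 10))

# Stack graph convolutional layers into a model.
model = GNNChain(GCNConv(3 => 8, relu),
                 GATv2Conv(8 => 8),
                 Dense(8, 1))

y = model(g, g.ndata.x)  # 1 × 10 output, one value per node
```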

src/GNNGraphs/transform.jl

Lines changed: 7 additions & 0 deletions

@@ -38,7 +38,14 @@ function add_self_loops(g::GNNGraph{<:ADJMAT_T})
                 g.ndata, g.edata, g.gdata)
 end
 
+"""
+    remove_self_loops(g::GNNGraph)
+
+Return a graph constructed from `g` where self-loops (edges from a node to itself)
+are removed.
 
+See also [`add_self_loops`](@ref) and [`remove_multi_edges`](@ref).
+"""
 function remove_self_loops(g::GNNGraph{<:COO_T})
     s, t = edge_index(g)
     w = get_edge_weight(g)
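
A short sketch of the documented behavior (the toy edge list is illustrative):

```julia
using GraphNeuralNetworks

# 3 nodes, 3 edges, one of which is the self-loop 1 → 1.
g = GNNGraph([1, 1, 2], [1, 2, 3])

g2 = remove_self_loops(g)
@assert g2.num_edges == 2                # the 1 → 1 edge is gone

g3 = add_self_loops(g2)
@assert g3.num_edges == 2 + g.num_nodes  # one self-loop added per node
```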

src/layers/conv.jl

Lines changed: 11 additions & 11 deletions

@@ -269,7 +269,7 @@ with ``z_i`` a normalization factor.
 In case `ein > 0` is given, edge features of dimension `ein` will be expected in the forward pass
 and the attention coefficients will be calculated as
 ```
-\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T LeakyReLU([W_3 \mathbf{e}_{j\to i}; W_2 \mathbf{x}_i; W_1 \mathbf{x}_j]))
+\alpha_{ij} = \frac{1}{z_i} \exp(LeakyReLU(\mathbf{a}^T [W_e \mathbf{e}_{j\to i}; W \mathbf{x}_i; W \mathbf{x}_j]))
 ````
 
 # Arguments
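
In code, the `ein > 0` path of `GATConv` looks like this (a sketch; the random graph and dimensions are illustrative):

```julia
using GraphNeuralNetworks, Flux

in_ch, ein, out_ch = 4, 3, 5
g = rand_graph(10, 30)
x = rand(Float32, in_ch, g.num_nodes)
e = rand(Float32, ein, g.num_edges)

# `(in, ein) => out` enables the edge-featured attention coefficients above.
# Self-loops are disabled, as in the tests, since added loops carry no edge features.
l = GATConv((in_ch, ein) => out_ch, add_self_loops = false)
y = l(g, x, e)   # out_ch × num_nodes
```
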
@@ -389,9 +389,9 @@ with ``z_i`` a normalization factor.
 
 In case `ein > 0` is given, edge features of dimension `ein` will be expected in the forward pass
 and the attention coefficients will be calculated as
+```math
+\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T LeakyReLU([W_3 \mathbf{e}_{j\to i}; W_2 \mathbf{x}_i; W_1 \mathbf{x}_j])).
 ```
-\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T LeakyReLU([W_3 \mathbf{e}_{j\to i}; W_2 \mathbf{x}_i; W_1 \mathbf{x}_j]))
-````
 
 # Arguments
 
@@ -420,7 +420,7 @@ struct GATv2Conv{T, A1, A2, A3, B, C<:AbstractMatrix} <: GNNLayer
 end
 
 @functor GATv2Conv
-Flux.trainable(l::GATv2Conv) = (l.dense_i, l.dense_j, l.dense_j, l.bias, l.a)
+Flux.trainable(l::GATv2Conv) = (l.dense_i, l.dense_j, l.dense_e, l.bias, l.a)
 
 GATv2Conv(ch::Pair{Int,Int}, args...; kws...) = GATv2Conv((ch[1], 0) => ch[2], args...; kws...)
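
With `dense_e` listed instead of `dense_j` twice, the edge projection now participates in training. A quick check mirroring the updated tests below (the sizes are illustrative):

```julia
using GraphNeuralNetworks, Flux

l = GATv2Conv((2, 4) => 3, add_self_loops = false)
@assert length(Flux.params(l)) == 6   # dense_e is now counted

# Forward pass with 4-dimensional edge features.
g = rand_graph(10, 30)
x, e = rand(Float32, 2, 10), rand(Float32, 4, 30)
y = l(g, x, e)                        # 3 × 10
```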

@@ -509,7 +509,7 @@ Gated graph convolution layer from [Gated Graph Sequence Neural Networks](https:
 
 Implements the recursion
 ```math
-\mathbf{h}^{(0)}_i = [\mathbf{x}_i || \mathbf{0}] \\
+\mathbf{h}^{(0)}_i = [\mathbf{x}_i; \mathbf{0}] \\
 \mathbf{h}^{(l)}_i = GRU(\mathbf{h}^{(l-1)}_i, \square_{j \in N(i)} W \mathbf{h}^{(l-1)}_j)
 ```
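
A minimal usage sketch (dimensions are illustrative; the input feature size must not exceed the output size, since ``\mathbf{h}^{(0)}_i`` zero-pads ``\mathbf{x}_i``):

```julia
using GraphNeuralNetworks

g = rand_graph(10, 30)
x = rand(Float32, 4, g.num_nodes)   # 4 ≤ 8, so x can be zero-padded

l = GatedGraphConv(8, 3)            # 8 output channels, 3 recursion steps
y = l(g, x)                         # 8 × num_nodes
```
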
@@ -572,14 +572,14 @@ Edge convolutional layer from paper [Dynamic Graph CNN for Learning on Point Clo
 
 Performs the operation
 ```math
-\mathbf{x}_i' = \square_{j \in N(i)} nn(\mathbf{x}_i || \mathbf{x}_j - \mathbf{x}_i)
+\mathbf{x}_i' = \square_{j \in N(i)}\, nn([\mathbf{x}_i; \mathbf{x}_j - \mathbf{x}_i])
 ```
 
 where `nn` generally denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.
 
 # Arguments
 
-- `nn`: A (possibly learnable) function acting on edge features.
+- `nn`: A (possibly learnable) function.
 - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
 """
 struct EdgeConv <: GNNLayer
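
For instance, with a linear layer as `nn` (a sketch; sizes are illustrative). Note that `nn` receives the concatenation ``[\mathbf{x}_i; \mathbf{x}_j - \mathbf{x}_i]``, i.e. twice the input feature dimension:

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(10, 30)
x = rand(Float32, 4, g.num_nodes)

# nn maps the 2 × 4 = 8 concatenated features to 16 output features.
l = EdgeConv(Dense(8, 16), aggr = max)
y = l(g, x)   # 16 × num_nodes
```
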
@@ -946,19 +946,19 @@ end
 Attention-based Graph Neural Network layer from paper [Attention-based
 Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735).
 
-THe forward pass is given by
+The forward pass is given by
 ```math
-\mathbf{x}_i' = \sum_{j \in {N(i) \cup \{i\}} \alpha_{ij} W \mathbf{x}_j
+\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W \mathbf{x}_j
 ```
 where the attention coefficients ``\alpha_{ij}`` are given by
 ```math
 \alpha_{ij} = \frac{e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_j)}}
-{\sum_{j'} e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_j'}}
+{\sum_{j'} e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_{j'})}}
 ```
 with the cosine distance defined by
 ```math
 \cos(\mathbf{x}_i, \mathbf{x}_j) =
-\mathbf{x}_i \cdot \mathbf{x}_j / \lVert\mathbf{x}_i\rVert \lVert\mathbf{x}_j\rVert``
+\frac{\mathbf{x}_i \cdot \mathbf{x}_j}{\lVert\mathbf{x}_i\rVert \lVert\mathbf{x}_j\rVert}
 ```
 and ``\beta`` a trainable parameter.
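
A minimal usage sketch (assuming the default constructor with ``\beta`` initialized to 1; the graph is illustrative):

```julia
using GraphNeuralNetworks

g = rand_graph(10, 30)
x = rand(Float32, 4, g.num_nodes)

l = AGNNConv()   # β is the trainable parameter
y = l(g, x)      # neighbors weighted by cosine-similarity attention
```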

test/layers/conv.jl

Lines changed: 24 additions & 10 deletions

@@ -103,18 +103,21 @@
         end
     end
 
-    @testset "bias=false" begin
-        @test length(Flux.params(GATConv(2=>3))) == 3
-        @test length(Flux.params(GATConv(2=>3, bias=false))) == 2
-    end
-
-
     @testset "edge features" begin
         ein = 3
         l = GATConv((in_channel, ein) => out_channel, add_self_loops=false)
         g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
         test_layer(l, g, rtol=1e-3, outsize=(out_channel, g.num_nodes))
-    end
+    end
+
+    @testset "num params" begin
+        l = GATConv(2 => 3, add_self_loops=false)
+        @test length(Flux.params(l)) == 3
+        l = GATConv((2,4) => 3, add_self_loops=false)
+        @test length(Flux.params(l)) == 4
+        l = GATConv((2,4) => 3, add_self_loops=false, bias=false)
+        @test length(Flux.params(l)) == 3
+    end
 end
 
 @testset "GATv2Conv" begin
@@ -127,9 +130,20 @@
         end
     end
 
-    @testset "bias=false" begin
-        @test length(Flux.params(GATv2Conv(2=>3))) == 5
-        @test length(Flux.params(GATv2Conv(2=>3, bias=false))) == 3
+    @testset "edge features" begin
+        ein = 3
+        l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops=false)
+        g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
+        test_layer(l, g, rtol=1e-3, outsize=(out_channel, g.num_nodes))
+    end
+
+    @testset "num params" begin
+        l = GATv2Conv(2 => 3, add_self_loops=false)
+        @test length(Flux.params(l)) == 5
+        l = GATv2Conv((2,4) => 3, add_self_loops=false)
+        @test length(Flux.params(l)) == 6
+        l = GATv2Conv((2,4) => 3, add_self_loops=false, bias=false)
+        @test length(Flux.params(l)) == 4
     end
 
     @testset "edge features" begin
