@doc raw"""
    GCNConv(in => out, σ=identity; bias=true, init=glorot_uniform, add_self_loops=true, edge_weight=false)

Graph convolutional layer from paper [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907).

Performs the operation
```math
\mathbf{x}'_i = \sum_{j\in N(i)} a_{ij} W \mathbf{x}_j
```
where ``a_{ij} = 1 / \sqrt{|N(i)||N(j)|}`` is a normalization factor computed from the node degrees.

If the input graph has weighted edges and `edge_weight=true`, then ``a_{ij}`` will be computed as
```math
a_{ij} = \frac{e_{j\to i}}{\sqrt{\sum_{j \in N(i)} e_{j\to i}} \sqrt{\sum_{i \in N(j)} e_{i\to j}}}
```

The input to the layer is a node feature array `X` of size `(num_features, num_nodes)`
and optionally an edge weight vector.

# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `σ`: Activation function. Default `identity`.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
- `edge_weight`: If `true`, consider the edge weights in the input graph (if available).
  Not compatible with `add_self_loops=true` at the moment. Default `false`.
"""
struct GCNConv{A<:AbstractMatrix, B, F} <: GNNLayer
    weight::A            # (out, in) learnable weight matrix
    bias::B              # learnable bias vector, or `false` when disabled
    σ::F                 # activation function applied elementwise
    add_self_loops::Bool # add self loops before the convolution
    edge_weight::Bool    # use the input graph's edge weights, if present
end
|
30 | 38 |
|
# Register GCNConv with Functors.jl so Flux can traverse its fields
# (collecting trainable parameters, moving to GPU, etc.).
@functor GCNConv
|
32 | 40 |
|
"""
    GCNConv(ch::Pair{Int,Int}, σ=identity; kwargs...)

Keyword-argument constructor: builds the weight matrix with `init`,
creates the bias (or the literal `false` when `bias=false`, which is a
no-op under broadcasting), and stores the layer options.
"""
function GCNConv(ch::Pair{Int,Int}, σ=identity;
                 init=glorot_uniform,
                 bias::Bool=true,
                 add_self_loops=true,
                 edge_weight=false)
    # Unpack the input => output feature dimensions.
    nin, nout = ch
    W = init(nout, nin)
    b = bias ? Flux.create_bias(W, true, nout) : false
    GCNConv(W, b, σ, add_self_loops, edge_weight)
end
|
41 | 51 |
|
42 |
| -function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T |
| 52 | +function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix) |
| 53 | + # Extract edge_weight from g if available and l.edge_weight == false, |
| 54 | + # otherwise return nothing. |
| 55 | + edge_weight = GNNGraphs._get_edge_weight(g, l.edge_weight) # vector or nothing |
| 56 | + return l(g, x, edge_weight) |
| 57 | +end |
| 58 | + |
| 59 | +function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}, edge_weight::EW) where |
| 60 | + {T, EW<:Union{Nothing,AbstractVector}} |
| 61 | + |
43 | 62 | if l.add_self_loops
|
| 63 | + @assert edge_weight === nothing |
44 | 64 | g = add_self_loops(g)
|
45 | 65 | end
|
46 | 66 | Dout, Din = size(l.weight)
|
47 | 67 | if Dout < Din
|
48 | 68 | x = l.weight * x
|
49 | 69 | end
|
50 | 70 | # @assert all(>(0), degree(g, T, dir=:in))
|
51 |
| - c = 1 ./ sqrt.(degree(g, T, dir=:in)) |
52 |
| - x = x .* c' |
53 |
| - x = propagate(copy_xj, g, +, xj=x) |
| 71 | + c = 1 ./ sqrt.(degree(g, T; dir=:in, edge_weight)) |
54 | 72 | x = x .* c'
|
| 73 | + if edge_weight === nothing |
| 74 | + x = propagate(copy_xj, g, +, xj=x) |
| 75 | + else |
| 76 | + x = propagate(e_mul_xj, g, +, xj=x, e=edge_weight) |
| 77 | + end |
| 78 | + x = x .* c' |
55 | 79 | if Dout >= Din
|
56 | 80 | x = l.weight * x
|
57 | 81 | end
|
|
0 commit comments