Skip to content

Commit df56b7e

Browse files
Add DConv layer (#441)
* First `DConv` draft * CUDA friendly * Add GNNLayer * Add test * Export `DConv` * Fix * Fix test * Add docs * Add type Co-authored-by: Carlo Lucibello <[email protected]> * Add propagate but need to fix the transpose part * Fix transpose * Add spaces Co-authored-by: Carlo Lucibello <[email protected]> * Add spaces Co-authored-by: Carlo Lucibello <[email protected]> * Add spaces Co-authored-by: Carlo Lucibello <[email protected]> --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent acf4b6a commit df56b7e

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ export
6767
SGConv,
6868
TAGConv,
6969
TransformerConv,
70+
DConv,
7071

7172
# layers/heteroconv
7273
HeteroGraphConv,

src/layers/conv.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2078,3 +2078,80 @@ function Base.show(io::IO, l::TransformerConv)
20782078
(in, ein), out = l.channels
20792079
print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
20802080
end
2081+
2082+
"""
2083+
DConv(ch::Pair{Int, Int}, K::Int; init = glorot_uniform, bias = true)
2084+
2085+
Diffusion convolution layer from the paper [Diffusion Convolutional Recurrent Neural Networks: Data-Driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926).
2086+
2087+
# Arguments
2088+
2089+
- `ch`: Pair of input and output dimensions.
2090+
- `K`: Number of diffusion steps.
2091+
- `init`: Weights' initializer. Default `glorot_uniform`.
2092+
- `bias`: Add learnable bias. Default `true`.
2093+
2094+
# Examples
2095+
```
2096+
julia> g = GNNGraph(rand(10, 10), ndata = rand(Float32, 2, 10));
2097+
2098+
julia> dconv = DConv(2 => 4, 4)
2099+
DConv(2 => 4, K=4)
2100+
2101+
julia> y = dconv(g, g.ndata.x);
2102+
2103+
julia> size(y)
2104+
(4, 10)
2105+
```
2106+
"""
2107+
struct DConv <: GNNLayer
2108+
in::Int
2109+
out::Int
2110+
weights::AbstractArray
2111+
bias::AbstractArray
2112+
K::Int
2113+
end
2114+
2115+
@functor DConv
2116+
2117+
function DConv(ch::Pair{Int, Int}, K::Int; init = glorot_uniform, bias = true)
2118+
in, out = ch
2119+
weights = init(2, K, out, in)
2120+
b = bias ? Flux.create_bias(weights, true, out) : false
2121+
DConv(in, out, weights, b, K)
2122+
end
2123+
2124+
function (l::DConv)(g::GNNGraph, x::AbstractMatrix)
2125+
#A = adjacency_matrix(g, weighted = true)
2126+
s, t = edge_index(g)
2127+
gt = GNNGraph(t, s, get_edge_weight(g))
2128+
deg_out = degree(g; dir = :out)
2129+
deg_in = degree(g; dir = :in)
2130+
deg_out = Diagonal(deg_out)
2131+
deg_in = Diagonal(deg_in)
2132+
2133+
h = l.weights[1,1,:,:] * x .+ l.weights[2,1,:,:] * x
2134+
2135+
T0 = x
2136+
if l.K > 1
2137+
# T1_in = T0 * deg_in * A'
2138+
#T1_out = T0 * deg_out' * A
2139+
T1_out = propagate(w_mul_xj, g, +; xj = T0*deg_out')
2140+
T1_in = propagate(w_mul_xj, gt, +; xj = T0*deg_in)
2141+
h = h .+ l.weights[1,2,:,:] * T1_in .+ l.weights[2,2,:,:] * T1_out
2142+
end
2143+
for i in 2:l.K
2144+
T2_in = propagate(w_mul_xj, gt, +; xj = T1_in*deg_in)
2145+
T2_in = 2 * T2_in - T0
2146+
T2_out = propagate(w_mul_xj, g ,+; xj = T1_out*deg_out')
2147+
T2_out = 2 * T2_out - T0
2148+
h = h .+ l.weights[1,i,:,:] * T2_in .+ l.weights[2,i,:,:] * T2_out
2149+
T1_in = T2_in
2150+
T1_out = T2_out
2151+
end
2152+
return h .+ l.bias
2153+
end
2154+
2155+
function Base.show(io::IO, l::DConv)
2156+
print(io, "DConv($(l.in) => $(l.out), K=$(l.K))")
2157+
end

test/layers/conv.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,3 +349,13 @@ end
349349
outsize = (in_channel, g.num_nodes))
350350
end
351351
end
352+
353+
@testset "DConv" begin
354+
K = [1, 2, 3] # for different number of hops
355+
for k in K
356+
l = DConv(in_channel => out_channel, k)
357+
for g in test_graphs
358+
test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes))
359+
end
360+
end
361+
end

0 commit comments

Comments
 (0)