Skip to content

Commit 18c4606

Browse files
authored
Creating dropout functionality in the GATConv and GATv2Conv Layers (#411)
* Adding the dropout functionalities to GAT and GATV2 Signed-off-by: achiverram28 <[email protected]> * Correcting dropout keyword Signed-off-by: achiverram28 <[email protected]> * Adding the test for dropout for GATConv and GATV2Conv Signed-off-by: achiverram28 <[email protected]> * Fix Signed-off-by: achiverram28 <[email protected]> * Fix in test Signed-off-by: achiverram28 <[email protected]> --------- Signed-off-by: achiverram28 <[email protected]>
1 parent 6f07d5b commit 18c4606

File tree

2 files changed

+21
-24
lines changed

2 files changed

+21
-24
lines changed

src/layers/conv.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ and the attention coefficients will be calculated as
364364
- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. Default `true`.
365365
- `negative_slope`: The parameter of LeakyReLU.Default `0.2`.
366366
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
367-
367+
- `dropout`: Dropout probability on the normalized attention coefficient. Default `0.0`.
368368
369369
# Examples
370370
@@ -384,7 +384,7 @@ l = GATConv(in_channel => out_channel, add_self_loops = false, bias = false; hea
384384
y = l(g, x)
385385
```
386386
"""
387-
struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B} <:
387+
struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, DV, T, A <: AbstractMatrix, F, B} <:
388388
GNNLayer
389389
dense_x::DX
390390
dense_e::DE
@@ -396,6 +396,7 @@ struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix,
396396
heads::Int
397397
concat::Bool
398398
add_self_loops::Bool
399+
dropout::DV
399400
end
400401

401402
@functor GATConv
@@ -405,7 +406,7 @@ GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args
405406

406407
function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
407408
heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
408-
init = glorot_uniform, bias::Bool = true, add_self_loops = true)
409+
init = glorot_uniform, bias::Bool = true, add_self_loops = true, dropout=0.0)
409410
(in, ein), out = ch
410411
if add_self_loops
411412
@assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
@@ -416,7 +417,7 @@ function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
416417
b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false
417418
a = init(ein > 0 ? 3out : 2out, heads)
418419
negative_slope = convert(Float32, negative_slope)
419-
GATConv(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops)
420+
GATConv(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops, dropout)
420421
end
421422

422423
(l::GATConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
@@ -448,6 +449,7 @@ function (l::GATConv)(g::AbstractGNNGraph, x,
448449
# a hand-written message passing
449450
m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e)
450451
α = softmax_edge_neighbors(g, m.logα)
452+
α = dropout(α, l.dropout)
451453
β = α .* m.Wxj
452454
x = aggregate_neighbors(g, +, β)
453455

@@ -518,6 +520,7 @@ and the attention coefficients will be calculated as
518520
- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. Default `true`.
519521
- `negative_slope`: The parameter of LeakyReLU.Default `0.2`.
520522
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
523+
- `dropout`: Dropout probability on the normalized attention coefficient. Default `0.0`.
521524
522525
# Examples
523526
```julia
@@ -540,7 +543,7 @@ e = randn(Float32, ein, length(s))
540543
y = l(g, x, e)
541544
```
542545
"""
543-
struct GATv2Conv{T, A1, A2, A3, B, C <: AbstractMatrix, F} <: GNNLayer
546+
struct GATv2Conv{T, A1, A2, A3, DV, B, C <: AbstractMatrix, F} <: GNNLayer
544547
dense_i::A1
545548
dense_j::A2
546549
dense_e::A3
@@ -552,6 +555,7 @@ struct GATv2Conv{T, A1, A2, A3, B, C <: AbstractMatrix, F} <: GNNLayer
552555
heads::Int
553556
concat::Bool
554557
add_self_loops::Bool
558+
dropout::DV
555559
end
556560

557561
@functor GATv2Conv
@@ -568,7 +572,8 @@ function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
568572
negative_slope = 0.2,
569573
init = glorot_uniform,
570574
bias::Bool = true,
571-
add_self_loops = true)
575+
add_self_loops = true,
576+
dropout=0.0)
572577
(in, ein), out = ch
573578

574579
if add_self_loops
@@ -586,7 +591,7 @@ function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
586591
a = init(out, heads)
587592
negative_slope = convert(eltype(dense_i.weight), negative_slope)
588593
GATv2Conv(dense_i, dense_j, dense_e, b, a, σ, negative_slope, ch, heads, concat,
589-
add_self_loops)
594+
add_self_loops, dropout)
590595
end
591596

592597
(l::GATv2Conv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
@@ -611,6 +616,7 @@ function (l::GATv2Conv)(g::AbstractGNNGraph, x,
611616

612617
m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e)
613618
α = softmax_edge_neighbors(g, m.logα)
619+
α = dropout(α, l.dropout)
614620
β = α .* m.Wxj
615621
x = aggregate_neighbors(g, +, β)
616622

test/layers/conv.jl

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -107,21 +107,21 @@ end
107107

108108
@testset "GATConv" begin
109109
for heads in (1, 2), concat in (true, false)
110-
l = GATConv(in_channel => out_channel; heads, concat)
110+
l = GATConv(in_channel => out_channel; heads, concat, dropout=0)
111111
for g in test_graphs
112112
test_layer(l, g, rtol = RTOL_LOW,
113-
exclude_grad_fields = [:negative_slope],
113+
exclude_grad_fields = [:negative_slope, :dropout],
114114
outsize = (concat ? heads * out_channel : out_channel,
115115
g.num_nodes))
116116
end
117117
end
118118

119119
@testset "edge features" begin
120120
ein = 3
121-
l = GATConv((in_channel, ein) => out_channel, add_self_loops = false)
121+
l = GATConv((in_channel, ein) => out_channel, add_self_loops = false, dropout=0)
122122
g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges))
123123
test_layer(l, g, rtol = RTOL_LOW,
124-
exclude_grad_fields = [:negative_slope],
124+
exclude_grad_fields = [:negative_slope, :dropout],
125125
outsize = (out_channel, g.num_nodes))
126126
end
127127

@@ -137,21 +137,21 @@ end
137137

138138
@testset "GATv2Conv" begin
139139
for heads in (1, 2), concat in (true, false)
140-
l = GATv2Conv(in_channel => out_channel, tanh; heads, concat)
140+
l = GATv2Conv(in_channel => out_channel, tanh; heads, concat, dropout=0)
141141
for g in test_graphs
142142
test_layer(l, g, rtol = RTOL_LOW, atol=ATOL_LOW,
143-
exclude_grad_fields = [:negative_slope],
143+
exclude_grad_fields = [:negative_slope, :dropout],
144144
outsize = (concat ? heads * out_channel : out_channel,
145145
g.num_nodes))
146146
end
147147
end
148148

149149
@testset "edge features" begin
150150
ein = 3
151-
l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops = false)
151+
l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops = false, dropout=0)
152152
g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges))
153153
test_layer(l, g, rtol = RTOL_LOW, atol=ATOL_LOW,
154-
exclude_grad_fields = [:negative_slope],
154+
exclude_grad_fields = [:negative_slope, :dropout],
155155
outsize = (out_channel, g.num_nodes))
156156
end
157157

@@ -163,15 +163,6 @@ end
163163
l = GATv2Conv((2, 4) => 3, add_self_loops = false, bias = false)
164164
@test length(Flux.params(l)) == 4
165165
end
166-
167-
@testset "edge features" begin
168-
ein = 3
169-
l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops = false)
170-
g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges))
171-
test_layer(l, g, rtol = RTOL_LOW, atol=ATOL_LOW,
172-
exclude_grad_fields = [:negative_slope],
173-
outsize = (out_channel, g.num_nodes))
174-
end
175166
end
176167

177168
@testset "GatedGraphConv" begin

0 commit comments

Comments (0)