@@ -364,7 +364,7 @@ and the attention coefficients will be calculated as
364
364
- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. Default `true`.
365
365
- `negative_slope`: The parameter of LeakyReLU. Default `0.2`.
366
366
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
367
-
367
+ - `dropout`: Dropout probability on the normalized attention coefficients. Default `0.0`.
368
368
369
369
# Examples
370
370
@@ -384,7 +384,7 @@ l = GATConv(in_channel => out_channel, add_self_loops = false, bias = false; hea
384
384
y = l(g, x)
385
385
```
386
386
"""
387
- struct GATConv{DX <: Dense , DE <: Union{Dense, Nothing} , T, A <: AbstractMatrix , F, B} < :
387
+ struct GATConv{DX <: Dense , DE <: Union{Dense, Nothing} , DV, T, A <: AbstractMatrix , F, B} < :
388
388
GNNLayer
389
389
dense_x:: DX
390
390
dense_e:: DE
@@ -396,6 +396,7 @@ struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix,
396
396
heads:: Int
397
397
concat:: Bool
398
398
add_self_loops:: Bool
399
+ dropout:: DV
399
400
end
400
401
401
402
@functor GATConv
@@ -405,7 +406,7 @@ GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args
405
406
406
407
function GATConv (ch:: Pair{NTuple{2, Int}, Int} , σ = identity;
407
408
heads:: Int = 1 , concat:: Bool = true , negative_slope = 0.2 ,
408
- init = glorot_uniform, bias:: Bool = true , add_self_loops = true )
409
+ init = glorot_uniform, bias:: Bool = true , add_self_loops = true , dropout = 0.0 )
409
410
(in, ein), out = ch
410
411
if add_self_loops
411
412
@assert ein== 0 " Using edge features and setting add_self_loops=true at the same time is not yet supported."
@@ -416,7 +417,7 @@ function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
416
417
b = bias ? Flux. create_bias (dense_x. weight, true , concat ? out * heads : out) : false
417
418
a = init (ein > 0 ? 3 out : 2 out, heads)
418
419
negative_slope = convert (Float32, negative_slope)
419
- GATConv (dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops)
420
+ GATConv (dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops, dropout )
420
421
end
421
422
422
423
# Graph-only call: run the layer on the node and edge features stored in `g`
# and return a graph built from `g` with the layer output as its `ndata`.
function (l::GATConv)(g::GNNGraph)
    h = l(g, node_features(g), edge_features(g))
    return GNNGraph(g; ndata = h)
end
@@ -448,6 +449,7 @@ function (l::GATConv)(g::AbstractGNNGraph, x,
448
449
# a hand-written message passing
449
450
m = apply_edges ((xi, xj, e) -> message (l, xi, xj, e), g, Wxi, Wxj, e)
450
451
α = softmax_edge_neighbors (g, m. logα)
452
+ α = dropout (α, l. dropout)
451
453
β = α .* m. Wxj
452
454
x = aggregate_neighbors (g, + , β)
453
455
@@ -518,6 +520,7 @@ and the attention coefficients will be calculated as
518
520
- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. Default `true`.
519
521
- `negative_slope`: The parameter of LeakyReLU. Default `0.2`.
520
522
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
523
+ - `dropout`: Dropout probability on the normalized attention coefficients. Default `0.0`.
521
524
522
525
# Examples
523
526
```julia
@@ -540,7 +543,7 @@ e = randn(Float32, ein, length(s))
540
543
y = l(g, x, e)
541
544
```
542
545
"""
543
- struct GATv2Conv{T, A1, A2, A3, B, C <: AbstractMatrix , F} <: GNNLayer
546
+ struct GATv2Conv{T, A1, A2, A3, DV, B, C <: AbstractMatrix , F} <: GNNLayer
544
547
dense_i:: A1
545
548
dense_j:: A2
546
549
dense_e:: A3
@@ -552,6 +555,7 @@ struct GATv2Conv{T, A1, A2, A3, B, C <: AbstractMatrix, F} <: GNNLayer
552
555
heads:: Int
553
556
concat:: Bool
554
557
add_self_loops:: Bool
558
+ dropout:: DV
555
559
end
556
560
557
561
@functor GATv2Conv
@@ -568,7 +572,8 @@ function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
568
572
negative_slope = 0.2 ,
569
573
init = glorot_uniform,
570
574
bias:: Bool = true ,
571
- add_self_loops = true )
575
+ add_self_loops = true ,
576
+ dropout= 0.0 )
572
577
(in, ein), out = ch
573
578
574
579
if add_self_loops
@@ -586,7 +591,7 @@ function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
586
591
a = init (out, heads)
587
592
negative_slope = convert (eltype (dense_i. weight), negative_slope)
588
593
GATv2Conv (dense_i, dense_j, dense_e, b, a, σ, negative_slope, ch, heads, concat,
589
- add_self_loops)
594
+ add_self_loops, dropout )
590
595
end
591
596
592
597
# Graph-only call: run the layer on the node and edge features stored in `g`
# and return a graph built from `g` with the layer output as its `ndata`.
function (l::GATv2Conv)(g::GNNGraph)
    h = l(g, node_features(g), edge_features(g))
    return GNNGraph(g; ndata = h)
end
@@ -611,6 +616,7 @@ function (l::GATv2Conv)(g::AbstractGNNGraph, x,
611
616
612
617
m = apply_edges ((xi, xj, e) -> message (l, xi, xj, e), g, Wxi, Wxj, e)
613
618
α = softmax_edge_neighbors (g, m. logα)
619
+ α = dropout (α, l. dropout)
614
620
β = α .* m. Wxj
615
621
x = aggregate_neighbors (g, + , β)
616
622
0 commit comments