(l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st
- (l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
+ (l::GlobalPool)(g::GNNGraph, ps, st) = GNNGraph(g, gdata = first(l(g, node_features(g), ps, st))), st
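A minimal usage sketch of the Lux-style `GlobalPool` call above, assuming `GlobalPool` wraps an aggregation such as `mean` and holds no trainable parameters or state (so empty named tuples suffice for `ps` and `st`):

```julia
using GNNLux, GNNGraphs, Statistics

g = rand_graph(10, 30, ndata = rand(Float32, 3, 10))  # 10 nodes, 30 edges, 3 features
pool = GlobalPool(mean)               # assumption: mean aggregation over nodes
ps, st = (;), (;)                     # assumption: no parameters or state to set up
u, st = pool(g, g.ndata.x, ps, st)    # u has size (3, 1): one pooled vector per graph
```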
+
+ @doc raw"""
+     GlobalAttentionPool(fgate, ffeat=identity)
+
+ Global soft attention layer from the [Gated Graph Sequence Neural
+ Networks](https://arxiv.org/abs/1511.05493) paper
+
+ ```math
+ \mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i)
+ ```
+
+ where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref)
+ operation:
+
+ ```math
+ \alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}}
+            {\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}.
+ ```
+
+ # Arguments
+
+ - `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
+   It is typically expressed by a neural network.
+
+ - `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
+   It is typically expressed by a neural network.
+
+ # Examples
+
+ ```julia
+ using Graphs, LuxCore, Lux, GNNLux, Random
+
+ rng = Random.default_rng()
+ chin = 6
+ chout = 5
+
+ fgate = Dense(chin, 1)
+ ffeat = Dense(chin, chout)
+ pool = GlobalAttentionPool(fgate, ffeat)
+
+ g = batch([GNNGraph(Graphs.random_regular_graph(10, 4),
+                     ndata = rand(Float32, chin, 10))
+            for i in 1:3])
+
+ ps = (fgate = LuxCore.initialparameters(rng, fgate), ffeat = LuxCore.initialparameters(rng, ffeat))
+ st = (fgate = LuxCore.initialstates(rng, fgate), ffeat = LuxCore.initialstates(rng, ffeat))
+
+ u, st = pool(g, g.ndata.x, ps, st)
+
+ @assert size(u) == (chout, g.num_graphs)
+ ```
+ """
+ struct GlobalAttentionPool{G, F}
+     fgate::G
+     ffeat::F
+ end
+
+ GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)
+
+ function (l::GlobalAttentionPool)(g, x, ps, st)
+     fgate = StatefulLuxLayer{true}(l.fgate, ps.fgate, _getstate(st, :fgate))
+     ffeat = StatefulLuxLayer{true}(l.ffeat, ps.ffeat, _getstate(st, :ffeat))
+     m = (; fgate, ffeat)
+     return GNNlib.global_attention_pool(m, g, x), st
+ end
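The `GNNlib.global_attention_pool` call above evaluates the docstring's equations; a rough sketch of that computation in terms of GNNGraphs' `softmax_nodes` and `reduce_nodes` (illustrative, not the library's verbatim source):

```julia
using GNNGraphs  # provides softmax_nodes and reduce_nodes

# Per-graph softmax of the gate scores, then a weighted sum of the
# transformed node features, yielding one vector per graph in the batch.
function attention_pool_sketch(fgate, ffeat, g::GNNGraph, x::AbstractMatrix)
    α = softmax_nodes(g, fgate(x))            # α_i, normalized within each graph
    return reduce_nodes(+, g, α .* ffeat(x))  # u_V, size (D_out, num_graphs)
end
```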
+
+ (l::GlobalAttentionPool)(g::GNNGraph, ps, st) = GNNGraph(g, gdata = first(l(g, node_features(g), ps, st))), st
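A quick check of the graph-in/graph-out method, reusing `pool`, `g`, `ps`, and `st` from the docstring example (assuming GNNGraphs stores unnamed `gdata` arrays under the default key `u`):

```julia
g2, st = pool(g, ps, st)   # returns a new GNNGraph carrying the pooled features
@assert size(g2.gdata.u) == (chout, g.num_graphs)
```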