3939
4040(l:: GlobalPool )(g:: GNNGraph , x:: AbstractArray , ps, st) =  GNNlib. global_pool (l, g, x), st
4141
42- (l:: GlobalPool )(g:: GNNGraph ) =  GNNGraph (g, gdata =  l (g, node_features (g), ps, st))
42+ (l:: GlobalPool )(g:: GNNGraph ) =  GNNGraph (g, gdata =  l (g, node_features (g), ps, st))
43+ 
44+ @doc  raw """ 
45+     GlobalAttentionPool(fgate, ffeat=identity) 
46+ 
47+ Global soft attention layer from the [Gated Graph Sequence Neural 
48+ Networks](https://arxiv.org/abs/1511.05493) paper 
49+ 
50+ ```math 
51+ \m athbf{u}_V = \s um_{i\i n V} \a lpha_i\,  f_{feat}(\m athbf{x}_i)
52+ ``` 
53+ 
54+ where the coefficients ``\a lpha_i`` are given by a [`softmax_nodes`](@ref) 
55+ operation: 
56+ 
57+ ```math 
58+ \a lpha_i = \f rac{e^{f_{gate}(\m athbf{x}_i)}}
59+                 {\s um_{i'\i n V} e^{f_{gate}(\m athbf{x}_{i'})}}. 
60+ ``` 
61+ 
62+ # Arguments 
63+ 
64+ - `fgate`: The function ``f_{gate}: \m athbb{R}^{D_{in}} \t o \m athbb{R}``.  
65+            It is typically expressed by a neural network. 
66+ 
67+ - `ffeat`: The function ``f_{feat}: \m athbb{R}^{D_{in}} \t o \m athbb{R}^{D_{out}}``.  
68+            It is typically expressed by a neural network. 
69+ 
70+ # Examples 
71+ 
72+ ```julia 
73+ using Graphs, LuxCore, Lux, GNNLux, Random 
74+ 
75+ rng = Random.default_rng() 
76+ chin = 6 
77+ chout = 5     
78+ 
79+ fgate = Dense(chin, 1) 
80+ ffeat = Dense(chin, chout) 
81+ pool = GlobalAttentionPool(fgate, ffeat) 
82+ 
83+ g = batch([GNNGraph(Graphs.random_regular_graph(10, 4),  
84+                          ndata=rand(Float32, chin, 10))  
85+                 for i=1:3]) 
86+ 
87+ ps = (fgate = LuxCore.initialparameters(rng, fgate), ffeat = LuxCore.initialparameters(rng, ffeat)) 
88+ st = (fgate = LuxCore.initialstates(rng, fgate), ffeat = LuxCore.initialstates(rng, ffeat)) 
89+ 
90+ u, st = pool(g, g.ndata.x, ps, st) 
91+ 
92+ @assert size(u) == (chout, g.num_graphs) 
93+ ``` 
94+ """ 
95+ struct  GlobalAttentionPool{G, F}
96+     fgate:: G 
97+     ffeat:: F 
98+ end 
99+ 
100+ GlobalAttentionPool (fgate) =  GlobalAttentionPool (fgate, identity)
101+ 
102+ function  (l:: GlobalAttentionPool )(g, x, ps, st)
103+     fgate =  StatefulLuxLayer {true} (l. fgate, ps. fgate, _getstate (st, :fgate ))
104+     ffeat =  StatefulLuxLayer {true} (l. ffeat, ps. ffeat, _getstate (st, :ffeat ))
105+     m =  (; fgate, ffeat)
106+     return  GNNlib. global_attention_pool (m, g, x), st
107+ end 
108+ 
109+ (l:: GlobalAttentionPool )(g:: GNNGraph ) =  GNNGraph (g, gdata =  l (g, node_features (g), ps, st))
0 commit comments