@@ -106,4 +106,66 @@ function (l::GlobalAttentionPool)(g, x, ps, st)
106106    return  GNNlib. global_attention_pool (m, g, x), st
107107end 
108108
109- (l:: GlobalAttentionPool )(g:: GNNGraph ) =  GNNGraph (g, gdata =  l (g, node_features (g), ps, st))
109+ (l:: GlobalAttentionPool )(g:: GNNGraph ) =  GNNGraph (g, gdata =  l (g, node_features (g), ps, st))
110+ 
111+ """ 
112+     TopKPool(adj, k, in_channel) 
113+ 
114+ Top-k pooling layer. 
115+ 
116+ # Arguments 
117+ 
118+ - `adj`: Adjacency matrix of a graph. 
119+ - `k`: Top-k nodes are selected to pool together. 
120+ - `in_channel`: The dimension of input channel. 
121+ """ 
122+ struct  TopKPool{T, S}
123+     A:: AbstractMatrix{T} 
124+     k:: Int 
125+     p:: AbstractVector{S} 
126+     Ã:: AbstractMatrix{T} 
127+ end 
128+ 
129+ function  TopKPool (adj:: AbstractMatrix , k:: Int , in_channel:: Int ; init =  glorot_uniform)
130+     TopKPool (adj, k, init (in_channel), similar (adj, k, k))
131+ end 
132+ 
133+ (t:: TopKPool )(x:: AbstractArray , ps, st) =  GNNlib. topk_pool (t, x)
134+ 
135+ 
136+ @doc  raw """ 
137+     Set2Set(n_in, n_iters, n_layers = 1) 
138+ 
139+ Set2Set layer from the paper [Order Matters: Sequence to sequence for sets](https://arxiv.org/abs/1511.06391). 
140+ 
141+ For each graph in the batch, the layer computes an output vector of size `2*n_in` by iterating the following steps `n_iters` times: 
142+ ```math 
143+ \m athbf{q} = \m athrm{LSTM}(\m athbf{q}_{t-1}^*)
144+ \a lpha_{i} = \f rac{\e xp(\m athbf{q}^T \m athbf{x}_i)}{\s um_{j=1}^N \e xp(\m athbf{q}^T \m athbf{x}_j)} 
145+ \m athbf{r} = \s um_{i=1}^N \a lpha_{i} \m athbf{x}_i
146+ \m athbf{q}^*_t = [\m athbf{q}; \m athbf{r}]
147+ ``` 
148+ where `N` is the number of nodes in the graph, `LSTM` is a Long-Short-Term-Memory network with `n_layers` layers, input size `2*n_in` and output size `n_in`. 
149+ 
150+ Given a batch of graphs `g` and node features `x`, the layer returns a matrix of size `(2*n_in, n_graphs)`. 
151+ ``` 
152+ """ 
153+ struct  Set2Set{L} <:  GNNContainerLayer{(:lstm,)} 
154+     lstm:: L 
155+     num_iters:: Int 
156+ end 
157+ 
158+ function  Set2Set (n_in:: Int , n_iters:: Int , n_layers:: Int  =  1 )
159+     @assert  n_layers ==  1  " multiple layers not implemented yet" # TODO 
160+     n_out =  2  *  n_in
161+     lstm =  Lux. LSTMCell (n_out =>  n_in)
162+     return  Set2Set (lstm, n_iters)
163+ end 
164+ 
165+ function  (l:: Set2Set )(g, x, ps, st)
166+     lstm =  StatefulLuxLayer {true} (l. lstm, ps. lstm, _getstate (st, :lstm ))
167+     m =  (; lstm, Wh =  ps. lstm. weight_hh)
168+     return  GNNlib. set2set_pool (m, g, x)
169+ end 
170+ 
171+ (l:: Set2Set )(g:: GNNGraph ) =  GNNGraph (g, gdata =  l (g, node_features (g), ps, st))
0 commit comments