Skip to content

Commit 6df678b

Browse files
committed
Remove Set2Set not working
1 parent c4986f6 commit 6df678b

File tree

1 file changed

+0
-38
lines changed

1 file changed

+0
-38
lines changed

GNNLux/src/layers/pool.jl

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -131,41 +131,3 @@ function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_un
131131
end
132132

133133
(t::TopKPool)(x::AbstractArray, ps, st) = GNNlib.topk_pool(t, x)
134-
135-
136-
@doc raw"""
137-
Set2Set(n_in, n_iters, n_layers = 1)
138-
139-
Set2Set layer from the paper [Order Matters: Sequence to sequence for sets](https://arxiv.org/abs/1511.06391).
140-
141-
For each graph in the batch, the layer computes an output vector of size `2*n_in` by iterating the following steps `n_iters` times:
142-
```math
143-
\mathbf{q} = \mathrm{LSTM}(\mathbf{q}_{t-1}^*)
144-
\alpha_{i} = \frac{\exp(\mathbf{q}^T \mathbf{x}_i)}{\sum_{j=1}^N \exp(\mathbf{q}^T \mathbf{x}_j)}
145-
\mathbf{r} = \sum_{i=1}^N \alpha_{i} \mathbf{x}_i
146-
\mathbf{q}^*_t = [\mathbf{q}; \mathbf{r}]
147-
```
148-
where `N` is the number of nodes in the graph, `LSTM` is a Long-Short-Term-Memory network with `n_layers` layers, input size `2*n_in` and output size `n_in`.
149-
150-
Given a batch of graphs `g` and node features `x`, the layer returns a matrix of size `(2*n_in, n_graphs)`.
151-
```
152-
"""
153-
struct Set2Set{L} <: GNNContainerLayer{(:lstm,)}
154-
lstm::L
155-
num_iters::Int
156-
end
157-
158-
function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
159-
@assert n_layers == 1 "multiple layers not implemented yet" #TODO
160-
n_out = 2 * n_in
161-
lstm = Lux.LSTMCell(n_out => n_in)
162-
return Set2Set(lstm, n_iters)
163-
end
164-
165-
function (l::Set2Set)(g, x, ps, st)
166-
lstm = StatefulLuxLayer{true}(l.lstm, ps.lstm, _getstate(st, :lstm))
167-
m = (; lstm, Wh = ps.lstm.weight_hh)
168-
return GNNlib.set2set_pool(m, g, x)
169-
end
170-
171-
(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))

0 commit comments

Comments
 (0)