Skip to content

Commit 4283e36

Browse files
committed
Add Set2Set pooling layer
1 parent 321dee7 commit 4283e36

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

GNNlib/src/layers/pool.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k)
2929
function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
3030
n_in = size(x, 1)
3131
qstar = zeros_like(x, (2*n_in, g.num_graphs))
32-
h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
33-
c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
32+
h = zeros_like(l.Wh, size(l.Wh, 2))
33+
c = zeros_like(l.Wh, size(l.Wh, 2))
3434
state = (h, c)
3535
for t in 1:l.num_iters
3636
q, state = l.lstm(qstar, state) # [n_in, n_graphs]

GraphNeuralNetworks/src/layers/pool.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
156156
end
157157

158158
function (l::Set2Set)(g, x)
159-
return GNNlib.set2set_pool(l, g, x)
159+
m = (; lstm, Wh = lstm.Wh)
160+
return GNNlib.set2set_pool(m, g, x)
160161
end
161162

162163
(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))

0 commit comments

Comments
 (0)