Skip to content

Commit 83207c4

Browse files
fix set2set
1 parent cc387e4 commit 83207c4

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

GNNlib/src/layers/pool.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
3131
qstar = zeros_like(x, (2*n_in, g.num_graphs))
3232
h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
3333
c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
34+
state = (h, c)
3435
for t in 1:l.num_iters
35-
h, c = l.lstm(qstar, (h, c)) # [n_in, n_graphs]
36-
q = h
36+
q, state = l.lstm(qstar, state) # [n_in, n_graphs]
3737
qn = broadcast_nodes(g, q) # [n_in, n_nodes]
3838
α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes]
3939
r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs]

0 commit comments

Comments
 (0)