Skip to content

Commit 2f9f0d1

Browse files
committed
Fix pooling layer
1 parent a8ec4b5 commit 2f9f0d1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

GNNLux/src/layers/pool.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ pool = GlobalPool(mean)
2525
2626
g = GNNGraph(erdos_renyi(10, 4))
2727
X = rand(32, 10)
28-
pool(g, X) # => 32x1 matrix
28+
pool(g, X, ps, st) # => 32x1 matrix
2929
3030
3131
g = MLUtils.batch([GNNGraph(erdos_renyi(10, 4)) for _ in 1:5])
3232
X = rand(32, 50)
33-
pool(g, X) # => 32x5 matrix
33+
pool(g, X, ps, st) # => 32x5 matrix
3434
```
3535
"""
3636
struct GlobalPool{F} <: GNNLayer
@@ -39,4 +39,4 @@ end
3939

4040
(l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st
4141

42-
(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))
42+
(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))

0 commit comments

Comments
 (0)