Skip to content

Commit e23036a

Browse files
use readout_nodes in pooling
1 parent dc792d5 commit e23036a

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

src/layers/pool.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,8 @@ struct GlobalPool{F} <: GNNLayer
3333
aggr::F
3434
end
3535

36-
function (l::GlobalPool)(g::GNNGraph, X::AbstractArray)
37-
if isnothing(g.graph_indicator)
38-
# assume only one graph
39-
indexes = fill!(similar(X, Int, g.num_nodes), 1)
40-
else
41-
indexes = g.graph_indicator
42-
end
43-
return NNlib.scatter(l.aggr, X, indexes)
36+
function (l::GlobalPool)(g::GNNGraph, x::AbstractArray)
37+
return readout_nodes(g, x, l.aggr)
4438
end
4539

4640
"""

0 commit comments

Comments
 (0)