Skip to content

Commit 7a3540e

Browse files
update nnlib
1 parent 2dd14fd commit 7a3540e

File tree

3 files changed

+3
-8
lines changed

3 files changed

+3
-8
lines changed

GNNlib/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ version = "1.0.0"
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
88
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
99
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
10-
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1211
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1312
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -32,7 +31,6 @@ CUDA = "5"
3231
ChainRulesCore = "1.24"
3332
DataStructures = "0.18"
3433
GNNGraphs = "1.4"
35-
GPUArraysCore = "0.1"
3634
LinearAlgebra = "1"
3735
MLUtils = "0.4"
3836
NNlib = "0.9"

GNNlib/src/layers/pool.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k)
2929
function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
3030
n_in = size(x, 1)
3131
qstar = zeros_like(x, (2*n_in, g.num_graphs))
32-
h = zeros_like(l.Wh, size(l.Wh, 2))
33-
c = zeros_like(l.Wh, size(l.Wh, 2))
32+
h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
33+
c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
3434
state = (h, c)
3535
for t in 1:l.num_iters
3636
q, state = l.lstm(qstar, state) # [n_in, n_graphs]

GraphNeuralNetworks/src/layers/pool.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,6 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
155155
return Set2Set(lstm, n_iters)
156156
end
157157

158-
function (l::Set2Set)(g, x)
159-
m = (; l.lstm, l.num_iters, Wh = l.lstm.Wh)
160-
return GNNlib.set2set_pool(m, g, x)
161-
end
158+
(l::Set2Set)(g, x) = GNNlib.set2set_pool(l, g, x)
162159

163160
(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))

0 commit comments

Comments
 (0)