Skip to content

Commit 628da46

Browse files
fix lstm
1 parent b7b773e commit 628da46

File tree

6 files changed

+18
-20
lines changed

6 files changed

+18
-20
lines changed

GNNLux/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GNNLux"
22
uuid = "e8545f4d-a905-48ac-a8c4-ca114b98986d"
33
authors = ["Carlo Lucibello and contributors"]
4-
version = "0.1.1"
4+
version = "0.2.0-DEV"
55

66
[deps]
77
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
@@ -18,7 +18,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1818
[compat]
1919
ConcreteStructs = "0.2.3"
2020
GNNGraphs = "1.3"
21-
GNNlib = "0.2.3"
21+
GNNlib = "1"
2222
Lux = "1"
2323
LuxCore = "1"
2424
NNlib = "0.9.21"

GNNlib/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GNNlib"
22
uuid = "a6a84749-d869-43f8-aacc-be26a1996e48"
33
authors = ["Carlo Lucibello and contributors"]
4-
version = "0.2.5"
4+
version = "1.0.0-DEV"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

GNNlib/src/layers/pool.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k)
2929
function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
3030
n_in = size(x, 1)
3131
qstar = zeros_like(x, (2*n_in, g.num_graphs))
32+
h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
33+
c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
3234
for t in 1:l.num_iters
33-
q = l.lstm(qstar) # [n_in, n_graphs]
35+
h, c = l.lstm(qstar, (h, c)) # [n_in, n_graphs]
36+
q = h
3437
qn = broadcast_nodes(g, q) # [n_in, n_nodes]
3538
α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes]
3639
r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs]

GraphNeuralNetworks/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1919
[compat]
2020
ChainRulesCore = "1"
2121
Flux = "0.15"
22-
GNNGraphs = "1.0"
23-
GNNlib = "0.2"
22+
GNNGraphs = "1"
23+
GNNlib = "1"
2424
LinearAlgebra = "1"
2525
MLUtils = "0.4"
2626
MacroTools = "0.5"

GraphNeuralNetworks/src/layers/pool.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -149,24 +149,19 @@ end
149149
Flux.@layer Set2Set
150150

151151
function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
152-
@assert n_layers >= 1
152+
@assert n_layers == 1 "multiple layers not implemented yet" #TODO
153153
n_out = 2 * n_in
154-
155-
if n_layers == 1
156-
lstm = LSTM(n_out => n_in)
157-
else
158-
layers = [LSTM(n_out => n_in)]
159-
for _ in 2:n_layers
160-
push!(layers, LSTM(n_in => n_in))
161-
end
162-
lstm = Chain(layers...)
163-
end
164-
154+
lstm = LSTMCell(n_out => n_in)
165155
return Set2Set(lstm, n_iters)
166156
end
167157

158+
function initialstates(cell::LSTMCell)
159+
h = zeros_like(cell.Wh, size(cell.Wh, 2))
160+
c = zeros_like(cell.Wh, size(cell.Wh, 2))
161+
return h, c
162+
end
163+
168164
function (l::Set2Set)(g, x)
169-
Flux.reset!(l.lstm)
170165
return GNNlib.set2set_pool(l, g, x)
171166
end
172167

GraphNeuralNetworks/test/layers/pool.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ end
7676

7777
n_in = 3
7878
n_iters = 2
79-
n_layers = 1
79+
n_layers = 1 #TODO test with more layers
8080
g = batch([rand_graph(10, 40, graph_type = GRAPH_T) for _ in 1:5])
8181
g = GNNGraph(g, ndata = rand(Float32, n_in, g.num_nodes))
8282
l = Set2Set(n_in, n_iters, n_layers)

0 commit comments

Comments
 (0)