Skip to content

Commit cc387e4

Browse files
fix gatedgraphconv
1 parent 6a23a70 commit cc387e4

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

GNNLux/src/layers/conv.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1261,7 +1261,11 @@ LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l
12611261

12621262
function (l::GatedGraphConv)(g, x, ps, st)
12631263
gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
1264-
fgru = (x, h) -> gru((x, (h,)))[1] # make the forward compatible with Flux.GRUCell style
1264+
# make the forward compatible with Flux.GRUCell style
1265+
function fgru(x, h)
1266+
y, (h, ) = gru((x, (h,)))
1267+
return y, h
1268+
end
12651269
m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims)
12661270
return GNNlib.gated_graph_conv(m, g, x), st
12671271
end

GNNlib/src/layers/conv.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,7 @@ function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix)
227227
for i in 1:(l.num_layers)
228228
m = view(l.weight, :, :, i) * h
229229
m = propagate(copy_xj, g, l.aggr; xj = m)
230-
# in gru forward, hidden state is first argument, input is second
231-
h = l.gru(m, h)
230+
_, h = l.gru(m, h)
232231
end
233232
return h
234233
end

0 commit comments

Comments
 (0)