|
60 | 60 |
|
61 | 61 | Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
|
62 | 62 |
|
| 63 | +# input from graph |
| 64 | +applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g))) |
| 65 | +applylayer(l::GNNLayer, g::GNNGraph) = l(g) |
| 66 | + |
| 67 | +# explicit input |
63 | 68 | applylayer(l, g::GNNGraph, x) = l(x)
|
64 | 69 | applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)
|
65 | 70 |
|
66 | 71 | # Handle Flux.Parallel
|
| 72 | +applylayer(l::Parallel, g::GNNGraph) = GNNGraph(g, ndata=applylayer(l, g, node_features(g))) |
67 | 73 | applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(f, g, x), l.connection, l.layers)
|
68 |
| -applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(f, g, x), l.connection, l.layers, xs) |
69 |
| -applylayer(l::Parallel, g::GNNGraph, xs::Tuple) = applylayer(l, g, xs...) |
70 | 74 |
|
| 75 | +# input from graph |
| 76 | +applychain(::Tuple{}, g::GNNGraph) = g |
| 77 | +applychain(fs::Tuple, g::GNNGraph) = applychain(tail(fs), applylayer(first(fs), g)) |
71 | 78 |
|
| 79 | +# explicit input |
72 | 80 | applychain(::Tuple{}, g::GNNGraph, x) = x
|
73 | 81 | applychain(fs::Tuple, g::GNNGraph, x) = applychain(tail(fs), g, applylayer(first(fs), g, x))
|
74 | 82 |
|
75 | 83 | (c::GNNChain)(g::GNNGraph, x) = applychain(Tuple(c.layers), g, x)
|
| 84 | +(c::GNNChain)(g::GNNGraph) = applychain(Tuple(c.layers), g) |
| 85 | + |
76 | 86 |
|
77 | 87 | Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i]...)
|
78 | 88 | Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) =
|
|
0 commit comments