60
60
61
61
Flux. functor (:: Type{<:GNNChain} , c) = c. layers, ls -> GNNChain (ls... )
62
62
63
+ # input from graph
64
+ applylayer (l, g:: GNNGraph ) = GNNGraph (g, ndata= l (node_features (g)))
65
+ applylayer (l:: GNNLayer , g:: GNNGraph ) = l (g)
66
+
67
+ # explicit input
63
68
applylayer (l, g:: GNNGraph , x) = l (x)
64
69
applylayer (l:: GNNLayer , g:: GNNGraph , x) = l (g, x)
65
70
@@ -68,11 +73,17 @@ applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylay
68
73
applylayer (l:: Parallel , g:: GNNGraph , xs:: Vararg{<:AbstractArray} ) = mapreduce ((f, x) -> applylayer (f, g, x), l. connection, l. layers, xs)
69
74
applylayer (l:: Parallel , g:: GNNGraph , xs:: Tuple ) = applylayer (l, g, xs... )
70
75
76
+ # input from graph
77
+ applychain (:: Tuple{} , g:: GNNGraph ) = g
78
+ applychain (fs:: Tuple , g:: GNNGraph ) = applychain (tail (fs), applylayer (first (fs), g))
71
79
80
+ # explicit input
72
81
applychain (:: Tuple{} , g:: GNNGraph , x) = x
73
82
applychain (fs:: Tuple , g:: GNNGraph , x) = applychain (tail (fs), g, applylayer (first (fs), g, x))
74
83
75
84
(c:: GNNChain )(g:: GNNGraph , x) = applychain (Tuple (c. layers), g, x)
85
+ (c:: GNNChain )(g:: GNNGraph ) = applychain (Tuple (c. layers), g)
86
+
76
87
77
88
Base. getindex (c:: GNNChain , i:: AbstractArray ) = GNNChain (c. layers[i]. .. )
78
89
Base. getindex (c:: GNNChain{<:NamedTuple} , i:: AbstractArray ) =
0 commit comments