1
+ # Missing Layers
2
+
3
+ # | Layer |Sparse Ops|Edge Weight|Edge Features| Heterograph | TemporalSnapshotsGNNGraphs |
4
+ # | :-------- | :---: |:---: |:---: | :---: | :---: |
5
+ # | [`AGNNConv`](@ref) | | | ✓ | | |
6
+ # | [`CGConv`](@ref) | | | ✓ | ✓ | ✓ |
7
+ # | [`EGNNConv`](@ref) | | | ✓ | | |
8
+ # | [`EdgeConv`](@ref) | | | | ✓ | |
9
+ # | [`GATConv`](@ref) | | | ✓ | ✓ | ✓ |
10
+ # | [`GATv2Conv`](@ref) | | | ✓ | ✓ | ✓ |
11
+ # | [`GatedGraphConv`](@ref) | ✓ | | | | ✓ |
12
+ # | [`GINConv`](@ref) | ✓ | | | ✓ | ✓ |
13
+ # | [`GMMConv`](@ref) | | | ✓ | | |
14
+ # | [`MEGNetConv`](@ref) | | | ✓ | | |
15
+ # | [`NNConv`](@ref) | | | ✓ | | |
16
+ # | [`ResGatedGraphConv`](@ref) | | | | ✓ | ✓ |
17
+ # | [`SAGEConv`](@ref) | ✓ | | | ✓ | ✓ |
18
+ # | [`SGConv`](@ref) | ✓ | | | | ✓ |
19
+ # | [`TransformerConv`](@ref) | | | ✓ | | |
20
+
21
+
22
# Graph convolutional layer (Kipf & Welling style; actual math lives in `GNNlib.gcn_conv`).
# Hyperparameters and initializer functions are stored here; the trainable
# parameters themselves are created by `LuxCore.initialparameters`.
@concrete struct GCNConv <: GNNLayer
    in_dims::Int          # number of input node features
    out_dims::Int         # number of output node features
    use_bias::Bool        # whether a bias vector is part of the parameters
    add_self_loops::Bool  # whether self-loops are added to the graph
    use_edge_weight::Bool # whether edge weights are used in the propagation
    init_weight           # rng-taking initializer for the weight matrix
    init_bias             # rng-taking initializer for the bias vector
    σ                     # activation function
end
1
32
2
- @doc raw """
3
- GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform)
33
"""
    GCNConv(ch::Pair{Int, Int}, σ = identity; kwargs...)

Construct a `GCNConv` layer mapping `ch.first` input features to `ch.second`
output features with activation `σ`.

# Keywords
- `init_weight`: weight initializer (default `glorot_uniform`).
- `init_bias`: bias initializer (default `zeros32`).
- `use_bias`: include a bias term (default `true`).
- `add_self_loops`: add self-loops before propagation (default `true`).
- `use_edge_weight`: use edge weights during propagation (default `false`).
- `allow_fast_activation`: replace `σ` with a faster equivalent via
  `NNlib.fast_act` when one exists (default `true`).
"""
function GCNConv(ch::Pair{Int, Int}, σ = identity;
                 init_weight = glorot_uniform,
                 init_bias = zeros32,
                 use_bias::Bool = true,
                 add_self_loops::Bool = true,
                 use_edge_weight::Bool = false,
                 allow_fast_activation::Bool = true)
    in_dims, out_dims = ch
    act = allow_fast_activation ? NNlib.fast_act(σ) : σ
    return GCNConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight,
                   init_weight, init_bias, act)
end
4
44
5
- Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244).
45
# Create the trainable parameters: a `(out_dims, in_dims)` weight matrix and,
# when `use_bias` is set, an `out_dims`-long bias vector.
function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv)
    weight = l.init_weight(rng, l.out_dims, l.in_dims)
    l.use_bias || return (; weight)
    return (; weight, bias = l.init_bias(rng, l.out_dims))
end
6
54
7
- Performs:
8
- ```math
9
- \m athbf{x}_i' = W_1 \m athbf{x}_i + \s quare_{j \i n \m athcal{N}(i)} W_2 \m athbf{x}_j
10
- ```
55
# Total trainable parameter count: weight matrix entries plus the optional bias.
function LuxCore.parameterlength(l::GCNConv)
    nweight = l.in_dims * l.out_dims
    return l.use_bias ? nweight + l.out_dims : nweight
end
LuxCore.statelength(d::GCNConv) = 0          # layer is stateless
LuxCore.outputsize(d::GCNConv) = (d.out_dims,)
11
58
12
- where the aggregation type is selected by `aggr`.
59
# Compact printing; non-default options are shown only when they deviate
# from the constructor defaults.
function Base.show(io::IO, l::GCNConv)
    print(io, "GCNConv(", l.in_dims, " => ", l.out_dims)
    l.σ == identity || print(io, ", ", l.σ)
    l.use_bias || print(io, ", use_bias=false")
    l.add_self_loops || print(io, ", add_self_loops=false")
    l.use_edge_weight && print(io, ", use_edge_weight=true")
    print(io, ")")
end
13
67
14
- # Arguments
68
# TODO norm_fn should be keyword argument only
# Forward-pass entry points. The first two methods normalize their arguments
# and funnel into the third, which delegates to the shared `GNNlib.gcn_conv`
# implementation. The empty state `st` is passed through unchanged.
function (l::GCNConv)(g, x, ps, st; conv_weight = nothing, edge_weight = nothing,
                      norm_fn = d -> 1 ./ sqrt.(d))
    return l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
end

function (l::GCNConv)(g, x, edge_weight, ps, st; conv_weight = nothing,
                      norm_fn = d -> 1 ./ sqrt.(d))
    return l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
end

function (l::GCNConv)(g, x, edge_weight, norm_fn, ps, st; conv_weight = nothing)
    return GNNlib.gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight, ps), st
end
15
75
16
- - `in`: The dimension of input features.
17
- - `out`: The dimension of output features.
18
- - `σ`: Activation function.
19
- - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
20
- - `bias`: Add learnable bias.
21
- - `init`: Weights' initializer.
76
# Chebyshev spectral graph convolutional layer (math in `GNNlib.cheb_conv`).
# `k` is the order of the Chebyshev polynomial expansion.
@concrete struct ChebConv <: GNNLayer
    in_dims::Int    # number of input node features
    out_dims::Int   # number of output node features
    use_bias::Bool  # whether a bias vector is part of the parameters
    k::Int          # polynomial order (number of weight slices)
    init_weight     # rng-taking initializer for the weight tensor
    init_bias       # rng-taking initializer for the bias vector
    σ               # activation function
end
22
85
23
- # Examples
86
"""
    ChebConv(ch::Pair{Int, Int}, k::Int, σ = identity; kwargs...)

Construct a `ChebConv` layer of polynomial order `k`, mapping `ch.first`
input features to `ch.second` output features with activation `σ`.

# Keywords
- `init_weight`: weight initializer (default `glorot_uniform`).
- `init_bias`: bias initializer (default `zeros32`).
- `use_bias`: include a bias term (default `true`).
- `allow_fast_activation`: replace `σ` with a faster equivalent via
  `NNlib.fast_act` when one exists (default `true`).
"""
function ChebConv(ch::Pair{Int, Int}, k::Int, σ = identity;
                  init_weight = glorot_uniform,
                  init_bias = zeros32,
                  use_bias::Bool = true,
                  allow_fast_activation::Bool = true)
    in_dims, out_dims = ch
    act = allow_fast_activation ? NNlib.fast_act(σ) : σ
    return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias, act)
end
24
95
25
- ```julia
26
- # create data
27
- s = [1,1,2,3]
28
- t = [2,3,1,1]
29
- in_channel = 3
30
- out_channel = 5
31
- g = GNNGraph(s, t)
32
- x = randn(Float32, 3, g.num_nodes)
96
# Create the trainable parameters: a `(out_dims, in_dims, k)` weight tensor
# (one slice per Chebyshev order) and, when `use_bias` is set, a bias vector.
function LuxCore.initialparameters(rng::AbstractRNG, l::ChebConv)
    weight = l.init_weight(rng, l.out_dims, l.in_dims, l.k)
    l.use_bias || return (; weight)
    return (; weight, bias = l.init_bias(rng, l.out_dims))
end
105
+
106
# Total trainable parameter count: the k weight slices plus the optional bias.
function LuxCore.parameterlength(l::ChebConv)
    nweight = l.in_dims * l.out_dims * l.k
    return l.use_bias ? nweight + l.out_dims : nweight
end
LuxCore.statelength(d::ChebConv) = 0          # layer is stateless
LuxCore.outputsize(d::ChebConv) = (d.out_dims,)
110
+
111
# Compact printing; non-default options are shown only when they deviate
# from the constructor defaults.
function Base.show(io::IO, l::ChebConv)
    # Bug fix: the struct field is `k` (lowercase); `l.K` raised an error
    # whenever a ChebConv was displayed. The printed label stays "K=".
    print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", K=", l.k)
    l.σ == identity || print(io, ", ", l.σ)
    l.use_bias || print(io, ", use_bias=false")
    print(io, ")")
end
33
117
34
- # create layer
35
- l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean)
118
# Forward pass: delegate to the shared `GNNlib.cheb_conv` implementation and
# pass the (empty) state `st` through unchanged.
function (l::ChebConv)(g, x, ps, st)
    return GNNlib.cheb_conv(l, g, x, ps), st
end
36
119
37
- # forward pass
38
- y = l(g, x)
39
- ```
40
- """
41
# Graph convolutional layer with separate weights for the node itself and its
# aggregated neighbors (math in `GNNlib.graph_conv`).
@concrete struct GraphConv <: GNNLayer
    in_dims::Int    # number of input node features
    out_dims::Int   # number of output node features
    use_bias::Bool  # whether a bias vector is part of the parameters
    init_weight     # rng-taking initializer for the weight matrices
    init_bias       # rng-taking initializer for the bias vector
    σ               # activation function
    aggr            # neighbor aggregation operator (e.g. +, mean, max)
end
50
129
51
-
52
130
function GraphConv (ch:: Pair{Int, Int} , σ = identity;
53
131
aggr = + ,
54
132
init_weight = glorot_uniform,
@@ -65,10 +143,10 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GraphConv)
65
143
weight2 = l. init_weight (rng, l. out_dims, l. in_dims)
66
144
if l. use_bias
67
145
bias = l. init_bias (rng, l. out_dims)
146
+ return (; weight1, weight2, bias)
68
147
else
69
- bias = false
148
+ return (; weight1, weight2)
70
149
end
71
- return (; weight1, weight2, bias)
72
150
end
73
151
74
152
function LuxCore. parameterlength (l:: GraphConv )
@@ -90,4 +168,4 @@ function Base.show(io::IO, l::GraphConv)
90
168
print (io, " )" )
91
169
end
92
170
93
- (l:: GraphConv )(g:: GNNGraph , x, ps, st) = GNNlib. graph_conv (l, g, x, ps), st
171
# Forward pass: delegate to the shared `GNNlib.graph_conv` implementation and
# pass the (empty) state `st` through unchanged.
function (l::GraphConv)(g, x, ps, st)
    return GNNlib.graph_conv(l, g, x, ps), st
end
0 commit comments