@@ -32,7 +32,7 @@ function GCNConv(ch::Pair{Int,Int}, σ=identity;
                  init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out, in)
-    b = Flux.create_bias(W, bias, out)
+    b = bias ? Flux.create_bias(W, true, out) : false
     GCNConv(W, b, σ)
 end

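For reference, a minimal sketch of what the new bias expression produces, assuming only Flux (the sizes below are illustrative):

```julia
using Flux

W = Flux.glorot_uniform(4, 3)          # stand-in for `init(out, in)` with out = 4, in = 3
b_on  = Flux.create_bias(W, true, 4)   # the `bias = true` branch: a trainable zeros(Float32, 4)
b_off = false                          # the `bias = false` branch returns literal `false`
x = ones(Float32, 3)
@assert W * x .+ b_off == W * x        # `false` broadcasts as a zero bias, leaving the output unchanged
```

The same pattern is applied to the ChebConv, GraphConv, GATConv, and NNConv constructors below.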
@@ -105,7 +105,7 @@ function ChebConv(ch::Pair{Int,Int}, k::Int;
                   init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out, in, k)
-    b = Flux.create_bias(W, bias, out)
+    b = bias ? Flux.create_bias(W, true, out) : false
     ChebConv(W, b, k)
 end

@@ -172,7 +172,7 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+;
     in, out = ch
     W1 = init(out, in)
     W2 = init(out, in)
-    b = Flux.create_bias(W1, bias, out)
+    b = bias ? Flux.create_bias(W1, true, out) : false
     GraphConv(W1, W2, b, σ, aggr)
 end

@@ -196,7 +196,7 @@


 @doc raw"""
-    GATConv(in => out, , σ=identity;
+    GATConv(in => out, σ=identity;
             heads=1,
             concat=true,
             init=glorot_uniform
@@ -224,7 +224,7 @@ with ``z_i`` a normalization factor.
 - `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads.
 - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
 """
-struct GATConv{T, A<:AbstractMatrix{T}, B} <: GNNLayer
+struct GATConv{T, A<:AbstractMatrix, B} <: GNNLayer
     weight::A
     bias::B
     a::A
@@ -239,12 +239,13 @@ end
 Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)

 function GATConv(ch::Pair{Int,Int}, σ=identity;
-                 heads::Int=1, concat::Bool=true, negative_slope=0.2f0,
+                 heads::Int=1, concat::Bool=true, negative_slope=0.2,
                  init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out*heads, in)
-    b = Flux.create_bias(W, bias, out*heads)
+    b = bias ? Flux.create_bias(W, true, out*heads) : false
     a = init(2*out, heads)
+    negative_slope = convert(eltype(W), negative_slope)
     GATConv(W, b, a, σ, negative_slope, ch, heads, concat)
 end

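A small sketch of what the added `convert` line buys: the slope is stored in the element type of `W`, so the LeakyReLU used in the attention computation stays in `Float32` (names are illustrative, assuming Flux is loaded):

```julia
using Flux

W = Flux.glorot_uniform(8, 3)               # glorot_uniform yields Float32 weights
negative_slope = convert(eltype(W), 0.2)    # 0.2f0 instead of a Float64 literal
y = leakyrelu(-1f0, negative_slope)
@assert negative_slope isa Float32
@assert y isa Float32 && y ≈ -0.2f0         # no accidental promotion to Float64
```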
@@ -356,20 +357,20 @@ end


 @doc raw"""
-    EdgeConv(f; aggr=max)
+    EdgeConv(nn; aggr=max)

 Edge convolutional layer from paper [Dynamic Graph CNN for Learning on Point Clouds](https://arxiv.org/abs/1801.07829).

 Performs the operation
 ```math
-\mathbf{x}_i' = \square_{j \in N(i)} f(\mathbf{x}_i || \mathbf{x}_j - \mathbf{x}_i)
+\mathbf{x}_i' = \square_{j \in N(i)} nn(\mathbf{x}_i || \mathbf{x}_j - \mathbf{x}_i)
 ```

-where `f` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.
+where `nn` generally denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.

 # Arguments

-- `f`: A (possibly learnable) function acting on edge features.
+- `nn`: A (possibly learnable) function acting on edge features.
 - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
 """
 struct EdgeConv <: GNNLayer
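A usage sketch for the renamed `nn` argument, assuming this is GraphNeuralNetworks.jl with Flux available; `rand_graph` and the concrete sizes are illustrative assumptions, not part of this diff:

```julia
using Flux, GraphNeuralNetworks

nin, nout = 3, 5
# `nn` sees vcat(x_i, x_j - x_i), i.e. 2 * nin features per edge
l = EdgeConv(Dense(2 * nin, nout); aggr=max)

g = rand_graph(10, 30)          # assumed helper: 10 nodes, 30 random edges
x = randn(Float32, nin, 10)     # one feature column per node
y = l(g, x)                     # size (nout, 10)
```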
@@ -405,9 +406,9 @@ Graph Isomorphism convolutional layer from paper [How Powerful are Graph Neural


 ```math
-\mathbf{x}_i' = f\left((1 + \epsilon) \mathbf{x}_i + \sum_{j \in N(i)} \mathbf{x}_j \right)
+\mathbf{x}_i' = f_\Theta\left((1 + \epsilon) \mathbf{x}_i + \sum_{j \in N(i)} \mathbf{x}_j \right)
 ```
-where `f` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.
+where ``f_\Theta`` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.

 # Arguments

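To make ``f_\Theta`` concrete, here is a worked single-node version of the formula (plain Flux, not the layer API; every name below is illustrative):

```julia
using Flux

f_theta = Dense(3, 4, relu)                   # stands in for f_Θ
ϵ = 0.1f0
x_i = randn(Float32, 3)                       # features of node i
neighbors = [randn(Float32, 3) for _ in 1:5]  # features of the j ∈ N(i)

x_i_new = f_theta((1 + ϵ) * x_i + sum(neighbors))   # the GIN update for a single node
```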
@@ -434,3 +435,77 @@ function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
     X, _ = propagate(l, g, +, X)
     X
 end
+
+
+@doc raw"""
+    NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform)
+
+The continuous kernel-based convolutional operator from the
+[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) paper.
+This convolution is also known as the edge-conditioned convolution from the
+[Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) paper.
+
+Performs the operation
+
+```math
+\mathbf{x}_i' = W \mathbf{x}_i + \square_{j \in N(i)} f_\Theta(\mathbf{e}_{j\to i})\,\mathbf{x}_j
+```
+
+where ``f_\Theta`` denotes a learnable function (e.g. a linear layer or a multi-layer perceptron).
+Given an input of batched edge features `e` of size `(num_edge_features, num_edges)`,
+the function `f` will return a batched array of matrices of size `(out, in, num_edges)`.
+For convenience, functions returning a single `(out*in, num_edges)` matrix are also allowed.
+
+# Arguments
+
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `f`: A (possibly learnable) function acting on edge features.
+- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
+- `σ`: Activation function.
+- `bias`: Add learnable bias.
+- `init`: Weights' initializer.
+"""
+struct NNConv <: GNNLayer
+    weight
+    bias
+    nn
+    σ
+    aggr
+end
+
+@functor NNConv
+
+function NNConv(ch::Pair{Int,Int}, nn, σ=identity; aggr=+, bias=true, init=glorot_uniform)
+    in, out = ch
+    W = init(out, in)
+    b = bias ? Flux.create_bias(W, true, out) : false
+    return NNConv(W, b, nn, σ, aggr)
+end
+
+function compute_message(l::NNConv, x_i, x_j, e_ij)
+    nin, nedges = size(x_i)
+    W = reshape(l.nn(e_ij), (:, nin, nedges))
+    x_j = reshape(x_j, (nin, 1, nedges))   # needed by batched_mul
+    m = NNlib.batched_mul(W, x_j)
+    return reshape(m, :, nedges)
+end
+
+function update_node(l::NNConv, m, x)
+    l.σ.(l.weight*x .+ m .+ l.bias)
+end
+
+function (l::NNConv)(g::GNNGraph, x::AbstractMatrix, e)
+    check_num_nodes(g, x)
+    x, _ = propagate(l, g, l.aggr, x, e)
+    return x
+end
+
+(l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g), edge_features(g)))
+
+function Base.show(io::IO, l::NNConv)
+    out, in = size(l.weight)
+    print(io, "NNConv($in => $out")
+    print(io, ", aggr=", l.aggr)
+    print(io, ")")
+end
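A usage sketch for the new layer, following the shapes stated in the docstring; `rand_graph` and the concrete sizes are illustrative assumptions:

```julia
using Flux, GraphNeuralNetworks

nin, nout, nedgef = 3, 4, 6
edge_net = Dense(nedgef, nout * nin)   # maps edge features to a flattened (nout, nin) kernel per edge
l = NNConv(nin => nout, edge_net, relu; aggr=+)

g = rand_graph(10, 30)                 # assumed helper: 10 nodes, 30 edges
x = randn(Float32, nin, 10)            # node features, one column per node
e = randn(Float32, nedgef, 30)         # edge features, one column per edge
y = l(g, x, e)                         # size (nout, 10)
```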