
Commit eeb450a

rework tests; add NNConv
1 parent d678551 commit eeb450a

5 files changed (+236 -180 lines)

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -33,8 +33,10 @@ NNlibCUDA = "0.1"
 julia = "1.6"

 [extras]
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Test", "Zygote"]
+test = ["Test", "Zygote", "FiniteDifferences", "ChainRulesTestUtils"]

src/layers/conv.jl

Lines changed: 73 additions & 7 deletions
@@ -356,20 +356,20 @@ end


 @doc raw"""
-    EdgeConv(f; aggr=max)
+    EdgeConv(nn; aggr=max)

 Edge convolutional layer from paper [Dynamic Graph CNN for Learning on Point Clouds](https://arxiv.org/abs/1801.07829).

 Performs the operation
 ```math
-\mathbf{x}_i' = \square_{j \in N(i)} f(\mathbf{x}_i || \mathbf{x}_j - \mathbf{x}_i)
+\mathbf{x}_i' = \square_{j \in N(i)} nn(\mathbf{x}_i || \mathbf{x}_j - \mathbf{x}_i)
 ```

-where `f` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.
+where `nn` generally denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.

 # Arguments

-- `f`: A (possibly learnable) function acting on edge features.
+- `nn`: A (possibly learnable) function acting on edge features.
 - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
 """
 struct EdgeConv <: GNNLayer
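
For orientation, a minimal EdgeConv usage sketch (hypothetical graph and dimensions; assumes the package's `GNNGraph` constructor and Flux's `Dense`):

```julia
using Flux, GraphNeuralNetworks

g = GNNGraph([1, 2, 3], [2, 3, 1])   # a 3-node directed cycle, as (source, target) vectors
x = randn(Float32, 4, 3)             # 4 features per node

# `nn` acts on the concatenation [x_i; x_j - x_i], i.e. 2*4 = 8 input features
l = EdgeConv(Dense(8, 16); aggr=max)
y = l(g, x)                          # 16×3 output node features
```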
@@ -405,13 +405,13 @@ Graph Isomorphism convolutional layer from paper [How Powerful are Graph Neural


 ```math
-\mathbf{x}_i' = f\left((1 + \epsilon) \mathbf{x}_i + \sum_{j \in N(i)} \mathbf{x}_j \right)
+\mathbf{x}_i' = f_\Theta\left((1 + \epsilon) \mathbf{x}_i + \sum_{j \in N(i)} \mathbf{x}_j \right)
 ```
-where `f` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.
+where ``f_\Theta`` typically denotes a learnable function, e.g. a linear layer or a multi-layer perceptron.

 # Arguments

-- `f`: A (possibly learnable) function acting on node features.
+- ``f_\Theta``: A (possibly learnable) function acting on node features.
 - `eps`: Weighting factor.
 """
 struct GINConv{R<:Real} <: GNNLayer
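
Analogously, a GINConv sketch under the docstring's stated arguments (a learnable function plus the `eps` weighting factor; dimensions and constructor signature are assumptions):

```julia
using Flux, GraphNeuralNetworks

g = GNNGraph([1, 2, 3], [2, 3, 1])
x = randn(Float32, 4, 3)

l = GINConv(Dense(4, 8), 0.01f0)   # f_Θ = Dense(4, 8), eps = 0.01
y = l(g, x)                        # 8×3 output node features
```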
@@ -434,3 +434,73 @@ function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
     X, _ = propagate(l, g, +, X)
     X
 end
+
+
+@doc raw"""
+    NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform)
+
+The continuous kernel-based convolutional operator from the
+[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) paper.
+This convolution is also known as the edge-conditioned convolution from the
+[Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) paper.
+
+Performs the operation
+
+```math
+\mathbf{x}_i' = \mathbf{W} \mathbf{x}_i + \square_{j \in N(i)} f_\Theta(\mathbf{e}_{j\to i})\,\mathbf{x}_j
+```
+
+where ``f_\Theta`` denotes a learnable function (e.g. a linear layer or a multi-layer perceptron).
+Given an input of batched edge features `e` of size `(num_edge_features, num_edges)`,
+the function `f` returns a batched array of matrices of size `(out, in, num_edges)`.
+For convenience, functions returning a single `(out*in, num_edges)` matrix are also allowed.
+
+# Arguments
+
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `f`: A (possibly learnable) function acting on edge features.
+- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
+- `σ`: Activation function.
+- `bias`: Add learnable bias.
+- `init`: Weights' initializer.
+"""
+struct NNConv <: GNNLayer
+    weight
+    bias
+    nn
+    σ
+    aggr
+end
+
+@functor NNConv
+
+function NNConv(ch::Pair{Int,Int}, nn, σ=identity; aggr=+, bias=true, init=glorot_uniform)
+    in, out = ch
+    W = init(out, in)
+    b = Flux.create_bias(W, bias, out)
+    return NNConv(W, b, nn, σ, aggr)
+end
+
+function compute_message(l::NNConv, x_i, x_j, e_ij)
+    nin, nedges = size(x_i)
+    W = reshape(l.nn(e_ij), (:, nin, nedges))
+    x_j = reshape(x_j, (nin, 1, nedges))    # batched_mul requires 3-dimensional inputs
+    m = NNlib.batched_mul(W, x_j)
+    return reshape(m, :, nedges)
+end
+
+update_node(l::NNConv, m, x) = l.weight*x + m
+
+function (l::NNConv)(g::GNNGraph, x::AbstractMatrix, e)
+    check_num_nodes(g, x)
+    x, _ = propagate(l, g, l.aggr, x, e)
+    return l.σ.(x .+ l.bias)
+end
+
+function Base.show(io::IO, l::NNConv)
+    out, in = size(l.weight)
+    print(io, "NNConv($in => $out")
+    print(io, ", aggr=", l.aggr)
+    print(io, ")")
+end
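
To make the shape conventions concrete, a minimal NNConv usage sketch (hypothetical dimensions): the edge network maps 2 edge features to `out*in = 8*4` entries per edge, matching the `(out*in, num_edges)` convention described in the docstring.

```julia
using Flux, GraphNeuralNetworks

g = GNNGraph([1, 2, 3], [2, 3, 1])   # 3 nodes, 3 edges
x = randn(Float32, 4, 3)             # node features: in = 4
e = randn(Float32, 2, 3)             # edge features: 2 per edge

l = NNConv(4 => 8, Dense(2, 8 * 4), relu; aggr=+)
y = l(g, x, e)                       # 8×3 output node features
```

Internally, each edge's `8*4` vector is reshaped to an `8×4` matrix `f_Θ(e_ij)` that multiplies `x_j`, as in `compute_message` above.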
