Skip to content

Commit afae377

Browse files
Merge pull request #51 from CarloLucibello/cl/cg
add Crystal graph convolution
2 parents de14ded + 45ae490 commit afae377

File tree

4 files changed

+104
-16
lines changed

4 files changed

+104
-16
lines changed

src/GraphNeuralNetworks.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ export
4545
GNNChain,
4646

4747
# layers/conv
48+
CGConv,
4849
ChebConv,
4950
EdgeConv,
5051
GATConv,
@@ -55,7 +56,7 @@ export
5556
NNConv,
5657
ResGatedGraphConv,
5758
SAGEConv,
58-
59+
5960
# layers/pool
6061
GlobalPool,
6162
TopKPool,

src/layers/conv.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,3 +622,91 @@ function Base.show(io::IO, l::ResGatedGraphConv)
622622
l.σ == identity || print(io, ", ", l.σ)
623623
print(io, ")")
624624
end
625+
626+
627+
628+
@doc raw"""
    CGConv((nin, ein) => nout, act=identity; residual=false, bias=true, init=glorot_uniform)

The crystal graph convolutional layer from the paper
[Crystal Graph Convolutional Neural Networks for an Accurate and
Interpretable Prediction of Material Properties](https://arxiv.org/pdf/1710.10324.pdf).
Performs the operation

```math
\mathbf{x}_i' = \mathbf{x}_i + \sum_{j\in N(i)}\sigma(W_f \mathbf{z}_{ij} + \mathbf{b}_f)\, act(W_s \mathbf{z}_{ij} + \mathbf{b}_s)
```

where ``\mathbf{z}_{ij}`` is the node and edge features concatenation
``[\mathbf{x}_i \,\|\, \mathbf{x}_j \,\|\, \mathbf{e}_{j\to i}]``
and ``\sigma`` is the sigmoid function.
The residual ``\mathbf{x}_i`` is added only if `residual=true` and the output size is the same
as the input size.

# Arguments

- `nin`: The dimension of input node features.
- `ein`: The dimension of input edge features.
- `nout`: The dimension of output node features.
- `act`: Activation function.
- `bias`: Add learnable bias.
- `init`: Weights' initializer.
- `residual`: Add a residual connection.

# Usage

```julia
x = rand(Float32, 2, g.num_nodes)
e = rand(Float32, 3, g.num_edges)

l = CGConv((2, 3) => 4, tanh)

y = l(g, x, e)    # size: (4, num_nodes)
```
"""
struct CGConv <: GNNLayer
    ch::Pair{NTuple{2,Int},Int}  # (nin, ein) => nout, kept for show/inspection
    dense_f::Dense               # gating branch: sigmoid(W_f z + b_f)
    dense_s::Dense               # signal branch: act(W_s z + b_s)
    residual::Bool               # add input features to output when sizes match
end
673+
674+
# Make the layer's fields (the two Dense sub-layers) traversable by
# Functors-based tooling, as done for the other conv layers in this file.
@functor CGConv

# Convenience constructor: accept the three dimensions as separate positional
# arguments and forward everything else to the Pair-based method.
function CGConv(nin::Int, ein::Int, out::Int, args...; kws...)
    return CGConv((nin, ein) => out, args...; kws...)
end
677+
678+
# Build the two internal Dense layers. Both consume the concatenation
# z_ij = [x_i; x_j; e_ij], hence the input width 2*nin + ein.
function CGConv(ch::Pair{NTuple{2,Int},Int}, act=identity; residual=false, bias=true, init=glorot_uniform)
    (nin, ein), nout = ch
    zin = 2 * nin + ein
    gate = Dense(zin, nout, sigmoid; bias, init)
    signal = Dense(zin, nout, act; bias, init)
    return CGConv(ch, gate, signal, residual)
end
684+
685+
# Forward pass on node features `x` (nin × num_nodes) and edge features
# `e` (ein × num_edges). Returns the convolved node features (nout × num_nodes).
function (l::CGConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
    check_num_nodes(g, x)
    check_num_edges(g, e)

    # Edge message: gated interaction on the concatenation z_ij = [x_i; x_j; e_ij].
    function message(xi, xj, e)
        z = vcat(xi, xj, e)
        return l.dense_f(z) .* l.dense_s(z)
    end

    # Sum the incoming messages for each node.
    m = propagate(message, g, +, xi=x, xj=x, e=e)
    if l.residual
        if size(x, 1) == size(m, 1)
            m += x
        else
            # Residual is silently skipped (with a warning) when nin != nout,
            # since the shapes cannot be added.
            @warn "number of output features different from number of input features, residual not applied."
        end
    end
    return m
end
704+
705+
# Whole-graph application: read node/edge features from `g` and return a new
# GNNGraph carrying the convolved node features as `ndata`.
function (l::CGConv)(g::GNNGraph)
    x = node_features(g)
    e = edge_features(g)
    return GNNGraph(g, ndata=l(g, x, e))
end
706+
707+
# Compact one-line display, e.g. `CGConv((2, 3) => 4, tanh, residual=false)`.
# The activation is shown only when it differs from `identity`, matching the
# other layers' show methods in this file.
function Base.show(io::IO, l::CGConv)
    print(io, "CGConv(", l.ch)
    act = l.dense_s.σ
    if act != identity
        print(io, ", ", act)
    end
    print(io, ", residual=", l.residual, ")")
end

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
function check_num_nodes(g::GNNGraph, x::AbstractArray)
22
@assert g.num_nodes == size(x, ndims(x))
33
end
4+
# Sanity check mirroring `check_num_nodes`: the trailing dimension of the
# edge-feature array `e` must equal the number of edges in `g`.
check_num_edges(g::GNNGraph, e::AbstractArray) =
    @assert g.num_edges == size(e, ndims(e))
47

58
sort_edge_index(eindex::Tuple) = sort_edge_index(eindex...)
69

test/layers/conv.jl

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,7 @@
124124
edim = 10
125125
nn = Dense(edim, out_channel * in_channel)
126126

127-
l = NNConv(in_channel => out_channel, nn)
128-
for g in test_graphs
129-
g = GNNGraph(g, edata=rand(T, edim, g.num_edges))
130-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
131-
end
132-
133-
l = NNConv(in_channel => out_channel, nn, tanh, bias=false, aggr=mean)
127+
l = NNConv(in_channel => out_channel, nn, tanh, bias=true, aggr=+)
134128
for g in test_graphs
135129
g = GNNGraph(g, edata=rand(T, edim, g.num_edges))
136130
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
@@ -140,10 +134,7 @@
140134
@testset "SAGEConv" begin
141135
l = SAGEConv(in_channel => out_channel)
142136
@test l.aggr == mean
143-
for g in test_graphs
144-
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
145-
end
146-
137+
147138
l = SAGEConv(in_channel => out_channel, tanh, bias=false, aggr=+)
148139
for g in test_graphs
149140
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
@@ -152,14 +143,19 @@
152143

153144

154145
@testset "ResGatedGraphConv" begin
155-
l = ResGatedGraphConv(in_channel => out_channel)
146+
l = ResGatedGraphConv(in_channel => out_channel, tanh, bias=true)
156147
for g in test_graphs
157-
test_layer(l, g, rtol=1e-5,)
148+
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
158149
end
150+
end
159151

160-
l = ResGatedGraphConv(in_channel => out_channel, tanh, bias=false)
152+
153+
@testset "CGConv" begin
154+
edim = 10
155+
l = CGConv((in_channel, edim) => out_channel, tanh, residual=false, bias=true)
161156
for g in test_graphs
162-
test_layer(l, g, rtol=1e-5,)
157+
g = GNNGraph(g, edata=rand(T, edim, g.num_edges))
158+
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
163159
end
164160
end
165161
end

0 commit comments

Comments
 (0)