Skip to content

Commit bb9cb67

Browse files
committed
Add CGConv docstring
1 parent 61137f7 commit bb9cb67

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

GNNLux/src/layers/conv.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,65 @@ function (l::AGNNConv)(g, x::AbstractMatrix, ps, st)
427427
return GNNlib.agnn_conv(m, g, x), st
428428
end
429429

430+
@doc raw"""
431+
CGConv((in, ein) => out, act = identity; residual = false,
432+
use_bias = true, init_weight = glorot_uniform, init_bias = zeros32)
433+
CGConv(in => out, ...)
434+
435+
The crystal graph convolutional layer from the paper
436+
[Crystal Graph Convolutional Neural Networks for an Accurate and
437+
Interpretable Prediction of Material Properties](https://arxiv.org/pdf/1710.10324.pdf).
438+
Performs the operation
439+
440+
```math
441+
\mathbf{x}_i' = \mathbf{x}_i + \sum_{j\in N(i)}\sigma(W_f \mathbf{z}_{ij} + \mathbf{b}_f)\, act(W_s \mathbf{z}_{ij} + \mathbf{b}_s)
442+
```
443+
444+
where ``\mathbf{z}_{ij}`` is the node and edge features concatenation
445+
``[\mathbf{x}_i; \mathbf{x}_j; \mathbf{e}_{j\to i}]``
446+
and ``\sigma`` is the sigmoid function.
447+
The residual ``\mathbf{x}_i`` is added only if `residual=true` and the output size is the same
448+
as the input size.
449+
450+
# Arguments
451+
452+
- `in`: The dimension of input node features.
453+
- `ein`: The dimension of input edge features.
454+
If `ein` is not given, assumes that no edge features are passed as input in the forward pass.
455+
- `out`: The dimension of output node features.
456+
- `act`: Activation function.
457+
- `residual`: Add a residual connection.
458+
- `init_weight`: Weights' initializer. Default `glorot_uniform`.
459+
- `init_bias`: Bias initializer. Default `zeros32`.
460+
- `use_bias`: Add learnable bias. Default `true`.
461+
462+
# Examples
463+
464+
```julia
465+
using GNNLux, Lux, Random
466+
467+
# initialize random number generator
468+
rng = Random.default_rng()
469+
470+
# create random graph
471+
g = rand_graph(rng, 5, 6)
472+
x = rand(rng, Float32, 2, g.num_nodes)
473+
e = rand(rng, Float32, 3, g.num_edges)
474+
475+
l = CGConv((2, 3) => 4, tanh)
476+
477+
# setup layer
478+
ps, st = LuxCore.setup(rng, l)
479+
480+
# forward pass
481+
y, st = l(g, x, e, ps, st) # size: (4, num_nodes)
482+
483+
# No edge features
484+
l = CGConv(2 => 4, tanh)
485+
ps, st = LuxCore.setup(rng, l)
486+
y, st = l(g, x, ps, st) # size: (4, num_nodes)
487+
```
488+
"""
430489
@concrete struct CGConv <: GNNContainerLayer{(:dense_f, :dense_s)}
431490
in_dims::NTuple{2, Int}
432491
out_dims::Int

0 commit comments

Comments
 (0)