@@ -427,6 +427,65 @@ function (l::AGNNConv)(g, x::AbstractMatrix, ps, st)
427427 return GNNlib. agnn_conv (m, g, x), st
428428end
429429
430+ @doc raw """
431+ CGConv((in, ein) => out, act = identity; residual = false,
432+ use_bias = true, init_weight = glorot_uniform, init_bias = zeros32)
433+ CGConv(in => out, ...)
434+
435+ The crystal graph convolutional layer from the paper
436+ [Crystal Graph Convolutional Neural Networks for an Accurate and
437+ Interpretable Prediction of Material Properties](https://arxiv.org/pdf/1710.10324.pdf).
438+ Performs the operation
439+
440+ ```math
441+ \m athbf{x}_i' = \m athbf{x}_i + \s um_{j\i n N(i)}\s igma(W_f \m athbf{z}_{ij} + \m athbf{b}_f)\, act(W_s \m athbf{z}_{ij} + \m athbf{b}_s)
442+ ```
443+
444+ where ``\m athbf{z}_{ij}`` is the node and edge features concatenation
445+ ``[\m athbf{x}_i; \m athbf{x}_j; \m athbf{e}_{j\t o i}]``
446+ and ``\s igma`` is the sigmoid function.
447+ The residual ``\m athbf{x}_i`` is added only if `residual=true` and the output size is the same
448+ as the input size.
449+
450+ # Arguments
451+
452+ - `in`: The dimension of input node features.
453+ - `ein`: The dimension of input edge features.
454+ If `ein` is not given, assumes that no edge features are passed as input in the forward pass.
455+ - `out`: The dimension of output node features.
456+ - `act`: Activation function.
457+ - `residual`: Add a residual connection.
458+ - `init_weight`: Weights' initializer. Default `glorot_uniform`.
459+ - `init_bias`: Bias initializer. Default `zeros32`.
460+ - `use_bias`: Add learnable bias. Default `true`.
461+
462+ # Examples
463+
464+ ```julia
465+ using GNNLux, Lux, Random
466+
467+ # initialize random number generator
468+ rng = Random.default_rng()
469+
470+ # create random graph
471+ g = rand_graph(rng, 5, 6)
472+ x = rand(rng, Float32, 2, g.num_nodes)
473+ e = rand(rng, Float32, 3, g.num_edges)
474+
475+ l = CGConv((2, 3) => 4, tanh)
476+
477+ # setup layer
478+ ps, st = LuxCore.setup(rng, l)
479+
480+ # forward pass
481+ y, st = l(g, x, e, ps, st) # size: (4, num_nodes)
482+
483+ # No edge features
484+ l = CGConv(2 => 4, tanh)
485+ ps, st = LuxCore.setup(rng, l)
486+ y, st = l(g, x, ps, st) # size: (4, num_nodes)
487+ ```
488+ """
430489@concrete struct CGConv <: GNNContainerLayer{(:dense_f, :dense_s)}
431490 in_dims:: NTuple{2, Int}
432491 out_dims:: Int
0 commit comments