Skip to content

Commit 8dae39c

Browse files
export aggregate_neighbors; improve docstrings
1 parent f241024 commit 8dae39c

File tree

5 files changed

+55
-10
lines changed

5 files changed

+55
-10
lines changed

docs/src/api/messagepassing.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Pages = ["messagepassing.md"]
1515

1616
```@docs
1717
apply_edges
18+
aggregate_neighbors
1819
propagate
1920
```
2021

src/GraphNeuralNetworks.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,21 @@ using .GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
2121

2222
export
2323
# utils
24-
reduce_nodes, reduce_edges,
25-
softmax_nodes, softmax_edges,
26-
broadcast_nodes, broadcast_edges,
24+
reduce_nodes,
25+
reduce_edges,
26+
softmax_nodes,
27+
softmax_edges,
28+
broadcast_nodes,
29+
broadcast_edges,
2730
softmax_edge_neighbors,
2831

2932
# msgpass
30-
apply_edges, propagate,
31-
copy_xj, copy_xi, xi_dot_xj,
33+
apply_edges,
34+
aggregate_neighbors,
35+
propagate,
36+
copy_xj,
37+
copy_xi,
38+
xi_dot_xj,
3239

3340
# layers/basic
3441
GNNLayer,

src/layers/conv.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -824,10 +824,33 @@ function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix)
824824
end
825825

826826
@doc raw"""
827+
MEGNetConv(ϕe, ϕv; aggr=mean)
827828
MEGNetConv(in => out; aggr=mean)
828829
829830
Convolution from [Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals](https://arxiv.org/pdf/1812.05055.pdf)
830-
paper.
831+
paper. In the forward pass, takes as inputs node features `x` and edge features `e` and returns
832+
updated features `x̄, ē` according to
833+
834+
```math
835+
ē = ϕe(vcat(xi, xj, e))
836+
x̄ = ϕv(vcat(x, \square_{j\in \mathcal{N}(i)} ē_{j\to i}))
837+
```
838+
`aggr` defines the aggregation to be performed.
839+
840+
If the neural networks `ϕe` and `ϕv` are not provided, they will be constructed from
841+
the `in` and `out` arguments instead as multi-layer perceptron with one hidden layer and `relu`
842+
activations.
843+
````
844+
845+
# Examples
846+
847+
```julia
848+
g = rand_graph(10, 30)
849+
x = randn(3, 10)
850+
e = randn(3, 30)
851+
m = MEGNetConv(3 => 3)
852+
x̄, ē = m(g, x, e)
853+
```
831854
"""
832855
struct MEGNetConv <: GNNLayer
833856
ϕe
@@ -837,6 +860,8 @@ end
837860

838861
@functor MEGNetConv
839862

863+
MEGNetConv(ϕe, ϕv; aggr=mean) = MEGNetConv(ϕe, ϕv, aggr)
864+
840865
function MEGNetConv(ch::Pair{Int,Int}; aggr=mean)
841866
nin, nout = ch
842867
ϕe = Chain(Dense(3nin, nout, relu),
@@ -845,7 +870,7 @@ function MEGNetConv(ch::Pair{Int,Int}; aggr=mean)
845870
ϕv = Chain(Dense(nin + nout, nout, relu),
846871
Dense(nout, nout))
847872

848-
MEGNetConv(ϕe, ϕv, aggr)
873+
MEGNetConv(ϕe, ϕv; aggr)
849874
end
850875

851876
function (l::MEGNetConv)(g::GNNGraph)

src/msgpass.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ l = GNNConv(10 => 20)
6161
l(g, x)
6262
```
6363
64-
See also [`apply_edges`](@ref).
64+
See also [`apply_edges`](@ref) and [`aggregate_neighbors`](@ref).
6565
"""
6666
function propagate end
6767

@@ -103,7 +103,7 @@ such tensors.
103103
a batch of edges. The output of `f` has to be an array (or a named tuple of arrays)
104104
with the same batch size.
105105
106-
See also [`propagate`](@ref).
106+
See also [`propagate`](@ref) and [`aggregate_neighbors`](@ref).
107107
"""
108108
function apply_edges end
109109

@@ -125,7 +125,19 @@ _gather(x::Nothing, i) = nothing
125125

126126

127127
## AGGREGATE NEIGHBORS
128+
@doc raw"""
129+
aggregate_neighbors(g::GNNGraph, aggr, m)
130+
131+
Given a graph `g`, edge features `m`, and an aggregation
132+
operator `aggr` (e.g `+, min, max, mean`), returns the new node
133+
features
134+
```math
135+
\mathbf{x}_i = \square_{j \in \mathcal{N}(i)} \mathbf{m}_{j\to i}
136+
```
128137
138+
Neighborhood aggregation is the second step of [`propagate`](@ref),
139+
where it comes after [`apply_edges`](@ref).
140+
"""
129141
function aggregate_neighbors(g::GNNGraph, aggr, m)
130142
s, t = edge_index(g)
131143
return _scatter(aggr, m, t)

test/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@
186186
end
187187

188188
@testset "MEGNetConv" begin
189-
l = MEGNetConv(in_channel, edim => out_channel, tanh, aggr=+)
189+
l = MEGNetConv(in_channel => out_channel, tanh, aggr=+)
190190
for g in test_graphs
191191
g = GNNGraph(g, edata=rand(T, in_channel, g.num_edges))
192192
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))

0 commit comments

Comments
 (0)