Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions GNNlib/src/msgpass.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
"""
propagate(fmsg, g, aggr [layer]; [xi, xj, e])
propagate(fmsg, g, aggr, [layer,] xi, xj, e=nothing)
propagate(fmsg, g, aggr; [xi, xj, e])
propagate(fmsg, g, aggr xi, xj, e=nothing)

Performs message passing on graph `g`. Takes care of materializing the node features on each edge,
applying the message function `fmsg`, and returning an aggregated message ``\\bar{\\mathbf{m}}``
(depending on the return value of `fmsg`, an array or a named tuple of
arrays with last dimension's size `g.num_nodes`).

If also a [`GNNLayer`](@ref) `layer` is provided, it will be passed to `fmsg`
as a first argument.

It can be decomposed in two steps:

```julia
Expand All @@ -35,10 +32,8 @@ providing as input `f` a closure.
with the same batch size. If also `layer` is passed to propagate,
the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)`
instead of `fmsg(xi, xj, e)`.
- `layer`: A [`GNNLayer`](@ref). If provided it will be passed to `fmsg` as a first argument.
- `aggr`: Neighborhood aggregation operator. Use `+`, `mean`, `max`, or `min`.


# Examples

```julia
Expand Down Expand Up @@ -86,8 +81,8 @@ end
## APPLY EDGES

"""
apply_edges(fmsg, g, [layer]; [xi, xj, e])
apply_edges(fmsg, g, [layer,] xi, xj, e=nothing)
apply_edges(fmsg, g; [xi, xj, e])
apply_edges(fmsg, g, xi, xj, e=nothing)

Returns the message from node `j` to node `i` applying
the message function `fmsg` on the edges in graph `g`.
Expand All @@ -99,9 +94,6 @@ The function `fmsg` operates on batches of edges, therefore
`xi`, `xj`, and `e` are tensors whose last dimension
is the batch size, or can be named tuples of
such tensors.

If also a [`GNNLayer`](@ref) `layer` is provided, it will be passed to `fmsg`
as a first argument.

# Arguments

Expand All @@ -117,7 +109,6 @@ as a first argument.
with the same batch size. If also `layer` is passed to propagate,
the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)`
instead of `fmsg(xi, xj, e)`.
- `layer`: A [`GNNLayer`](@ref). If provided it will be passed to `fmsg` as a first argument.

See also [`propagate`](@ref) and [`aggregate_neighbors`](@ref).
"""
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ prettyurls = get(ENV, "CI", nothing) == "true"
mathengine = MathJax3()

makedocs(;
modules = [GraphNeuralNetworks, GNNGraphs],
modules = [GraphNeuralNetworks, GNNGraphs, GNNlib],
doctest = false,
clean = true,
plugins = [interlinks],
Expand Down
20 changes: 10 additions & 10 deletions docs/src/api/messagepassing.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ Pages = ["messagepassing.md"]
## Interface

```@docs
apply_edges
aggregate_neighbors
propagate
GNNlib.apply_edges
GNNlib.aggregate_neighbors
GNNlib.propagate
```

## Built-in message functions

```@docs
copy_xi
copy_xj
xi_dot_xj
xi_sub_xj
xj_sub_xi
e_mul_xj
w_mul_xj
GNNlib.copy_xi
GNNlib.copy_xj
GNNlib.xi_dot_xj
GNNlib.xi_sub_xj
GNNlib.xj_sub_xi
GNNlib.e_mul_xj
GNNlib.w_mul_xj
```
Loading