Skip to content

Commit f49aae3

Browse files
docs
1 parent adb3cdb commit f49aae3

File tree

8 files changed

+54
-15
lines changed

8 files changed

+54
-15
lines changed

docs/src/api/messagepassing.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,15 @@ Order = [:type, :function]
1111
Pages = ["messagepassing.md"]
1212
```
1313

14-
## Docs
14+
## Interface
1515

1616
```@docs
1717
apply_edges
1818
propagate
1919
```
20+
21+
## Built-in message functions
22+
23+
```@docs
24+
copyxj
25+
```

docs/src/messagepassing.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,9 @@ end
8383
```
8484

8585
See the [`GATConv`](@ref) implementation [here](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/src/layers/conv.jl) for a more complex example.
86+
87+
88+
## Built-in message functions
89+
90+
In order to exploit optimized specializations of the [`propagate`](@ref), it is recommended
91+
to use built-in message functions such as [`copyxj`](@ref) whenever possible.

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ export
3232

3333
# msgpass
3434
apply_edges, propagate,
35+
copyxj,
3536

3637
# layers/basic
3738
GNNLayer,

src/gnngraph.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,26 @@ function GNNGraph(g::AbstractGraph; kws...)
160160
end
161161

162162

163-
function GNNGraph(g::GNNGraph; ndata=g.ndata, edata=g.edata, gdata=g.gdata)
163+
function GNNGraph(g::GNNGraph; ndata=g.ndata, edata=g.edata, gdata=g.gdata, graph_type=nothing)
164164

165165
ndata = normalize_graphdata(ndata, default_name=:x, n=g.num_nodes)
166166
edata = normalize_graphdata(edata, default_name=:e, n=g.num_edges, duplicate_if_needed=true)
167167
gdata = normalize_graphdata(gdata, default_name=:u, n=g.num_graphs)
168-
169-
GNNGraph(g.graph,
168+
169+
if !isnothing(graph_type)
170+
if graph_type == :coo
171+
graph, num_nodes, num_edges = to_coo(g.graph; g.num_nodes)
172+
elseif graph_type == :dense
173+
graph, num_nodes, num_edges = to_dense(g.graph)
174+
elseif graph_type == :sparse
175+
graph, num_nodes, num_edges = to_sparse(g.graph)
176+
end
177+
@assert num_nodes == g.num_nodes
178+
@assert num_edges == g.num_edges
179+
else
180+
graph = g.graph
181+
end
182+
GNNGraph(graph,
170183
g.num_nodes, g.num_edges, g.num_graphs,
171184
g.graph_indicator,
172185
ndata, edata, gdata)

src/layers/conv.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
4747
if Dout < Din
4848
x = l.weight * x
4949
end
50+
# @assert all(>(0), degree(g, T, dir=:in))
5051
c = 1 ./ sqrt.(degree(g, T, dir=:in))
5152
x = x .* c'
5253
x = propagate(copyxj, g, +, xj=x)

src/msgpass.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
"""
22
propagate(f, g, aggr; xi, xj, e) -> m̄
33
4-
Performs message passing on graph `g`.
5-
6-
Takes care of materializing the node features on each edge,
7-
applying the message function, and returning an aggregated message ``\bar{\mathbf{m}}``
4+
Performs message passing on graph `g`. Takes care of materializing the node features on each edge,
5+
applying the message function, and returning an aggregated message ``\\bar{\\mathbf{m}}``
86
(depending on the return value of `f`, an array or a named tuple of
97
arrays with last dimension's size `g.num_nodes`).
108
@@ -139,16 +137,25 @@ _scatter(aggr, m::AbstractArray, t) = NNlib.scatter(aggr, m, t)
139137

140138

141139
### SPECIALIZATIONS OF PROPAGATE ###
142-
copyxi(xi, xj, e) = xi
140+
"""
141+
copyxj(xi, xj, e) = xj
142+
"""
143143
copyxj(xi, xj, e) = xj
144-
ximulxj(xi, xj, e) = xi .* xj
145-
xiaddxj(xi, xj, e) = xi .+ xj
146144

147-
function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(+), xi, xj, e)
145+
# copyxi(xi, xj, e) = xi
146+
# ximulxj(xi, xj, e) = xi .* xj
147+
# xiaddxj(xi, xj, e) = xi .+ xj
148+
149+
function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
148150
A = adjacency_matrix(g)
149151
return xj * A
150152
end
151153

152-
# TODO divide by degree
153-
# propagate(::typeof(copyxj), g::GNNGraph, ::typeof(mean), xi, xj, e)
154+
# function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
155+
# A = adjacency_matrix(g)
156+
# degs = vec(sum(A; dims=2))
157+
# D = Diagonal(ofeltype(xj, 1) ./ degs)
158+
# # A, D = _aa(g, xj)
159+
# return xj * A * D
160+
# end
154161

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ tests = [
2727
!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
2828

2929
# Testing all graph types. :sparse is a bit broken at the moment
30-
@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo,:sparse,:dense)
30+
@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo,)
3131

3232
global GRAPH_T = graph_type
3333
global TEST_GPU = CUDA.functional() && GRAPH_T != :sparse

test/test_utils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ function test_layer(l, g::GNNGraph; atol = 1e-7, rtol = 1e-5,
5050
if !isnothing(outsize)
5151
@test size(y) == outsize
5252
end
53+
54+
# test same output on different graph formats
55+
gcoo = GNNGraph(g, graph_type=:coo)
56+
ycoo = f(l, gcoo, x)
57+
@test ycoo y
5358

5459
g′ = f(l, g)
5560
@test g′.ndata.x y

0 commit comments

Comments
 (0)