Commit c18a71f

remove update_global
1 parent 1c798dd commit c18a71f

11 files changed (+56 -69 lines)

docs/make.jl

Lines changed: 3 additions & 2 deletions

@@ -1,10 +1,11 @@
-using Flux, NNlib, GraphNeuralNetworks
+using Flux, NNlib, GraphNeuralNetworks, LightGraphs, SparseArrays
 using Documenter
 
 DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup, :(using GraphNeuralNetworks); recursive=true)
 
 makedocs(;
-         modules=[GraphNeuralNetworks],
+         modules=[GraphNeuralNetworks, NNlib, Flux, LightGraphs, SparseArrays],
+         doctest=false, clean=true,
          sitename = "GraphNeuralNetworks.jl",
          pages = ["Home" => "index.md",
                   "GNNGraph" => "gnngraph.md",

docs/src/api/basic.md

Lines changed: 6 additions & 6 deletions

@@ -2,11 +2,11 @@
 
 ```@index
 Order = [:type, :function]
-Pages = ["api/basics.md"]
+Modules = [GraphNeuralNetworks, LightGraphs, Flux, NNlib]
+Pages = ["api/basics.md"]
 ```
 
-```@autodocs
-Modules = [GraphNeuralNetworks]
-Pages = ["layers/basic.jl"]
-Private = false
-```
+```@docs
+GNNLayer
+GNNChain
+```

docs/src/api/gnngraph.md

Lines changed: 5 additions & 0 deletions

@@ -10,3 +10,8 @@ Modules = [GraphNeuralNetworks]
 Pages = ["gnngraph.jl"]
 Private = false
 ```
+```@docs
+Flux.batch
+SparseArrays.blockdiag
+LightGraphs.adjacency_matrix
+```

docs/src/gnngraph.md

Lines changed: 1 addition & 1 deletion

@@ -67,7 +67,7 @@ using Flux
 
 gall = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(3,10)) for _ in 1:100])
 
-subgraph(gall, 2:3)
+getgraph(gall, 2:3)
 
 
 # DataLoader compatibility
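
`getgraph` is the new name of `subgraph`; behavior is unchanged, and it still returns the extracted graph together with a vector mapping its nodes back to the batched graph. A minimal sketch of the updated call, using the same setup as the snippet above:

```julia
using Flux, LightGraphs, GraphNeuralNetworks

# Batch 100 random graphs, each with 10 nodes and 3 node features.
gall = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(3, 10)) for _ in 1:100])

# Extract component graphs 2 and 3; `nodes[i]` is the index in `gall`
# of node `i` of the extracted graph.
g23, nodes = getgraph(gall, 2:3)

g23.num_graphs  # 2
g23.num_nodes   # 20
```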

docs/src/messagepassing.md

Lines changed: 0 additions & 3 deletions

@@ -2,6 +2,3 @@
 
 TODO
 
-```@docs
-propagate
-```

examples/graph_classification_tudataset.jl

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ function eval_loss_accuracy(model, data_loader, device)
        y = g.gdata.y
        ŷ = model(g, g.ndata.X) |> vec
        loss += logitbinarycrossentropy(ŷ, y) * n
-        acc += mean((2 .* ŷ .- 1) .* (2 .* y .- 1) .> 0) * n
+        acc += mean((ŷ .> 0) .== y) * n
        ntot += n
    end
    return (loss = round(loss/ntot, digits=4), acc = round(acc*100/ntot, digits=2))
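
The new accuracy formula matches the loss: with `logitbinarycrossentropy`, `ŷ` holds logits, so the decision threshold is 0, whereas the old expression `(2ŷ - 1)(2y - 1) > 0` thresholds at `ŷ > 0.5`. A standalone check on hypothetical logits:

```julia
using Statistics: mean

ŷ = [2.3, -0.4, 0.1, -1.7]   # hypothetical logits from a binary classifier
y = [1, 0, 1, 1]             # ground-truth labels in {0, 1}

# Old formula: counts a prediction as correct only when (ŷ > 0.5) matches y,
# so it misjudges the positive logit 0.1.
mean((2 .* ŷ .- 1) .* (2 .* y .- 1) .> 0)   # 0.5

# New formula: a logit is a positive prediction when it exceeds 0.
mean((ŷ .> 0) .== y)                        # 0.75
```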

src/GraphNeuralNetworks.jl

Lines changed: 3 additions & 3 deletions

@@ -20,18 +20,18 @@ export
     # gnngraph
     GNNGraph,
     edge_index,
-    node_features, edge_features, global_features,
+    node_features, edge_features, graph_features,
     adjacency_list, normalized_laplacian, scaled_laplacian,
     add_self_loops, remove_self_loops,
-    subgraph,
+    getgraph,
 
     # from LightGraphs
     adjacency_matrix,
     # from SparseArrays
     sprand, sparse,
 
     # msgpass
-    # update, update_edge, update_global, message, propagate,
+    # update, update_edge, message, propagate,
 
     # layers/basic
     GNNLayer,

src/gnngraph.jl

Lines changed: 10 additions & 13 deletions

@@ -10,7 +10,7 @@ const ADJLIST_T = AbstractVector{T} where T <: AbstractVector
 const ADJMAT_T = AbstractMatrix
 const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
 
-"""
+"""
     GNNGraph(data; [graph_type, ndata, edata, gdata, num_nodes, graph_indicator, dir])
     GNNGraph(g::GNNGraph; [ndata, edata, gdata])
 
@@ -416,9 +416,6 @@
 """
     blockdiag(xs::GNNGraph...)
 
-Batch togheter multiple `GNNGraph`s into a single one
-containing the total number of nodes and edges of the original graphs.
-
 Equivalent to [`Flux.batch`](@ref).
 """
 function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
@@ -432,7 +429,7 @@
 """
     batch(xs::Vector{<:GNNGraph})
 
-Batch togheter multiple `GNNGraph`s into a single one
+Batch together multiple `GNNGraph`s into a single one
 containing the total number of nodes and edges of the original graphs.
 
 Equivalent to [`SparseArrays.blockdiag`](@ref).
@@ -441,28 +438,28 @@ Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...)
 
 ### LearnBase compatibility
 LearnBase.nobs(g::GNNGraph) = g.num_graphs
-LearnBase.getobs(g::GNNGraph, i) = subgraph(g, i)[1]
+LearnBase.getobs(g::GNNGraph, i) = getgraph(g, i)[1]
 
 # Flux's Dataloader compatibility. Related PR https://github.com/FluxML/Flux.jl/pull/1683
 Flux.Data._nobs(g::GNNGraph) = g.num_graphs
-Flux.Data._getobs(g::GNNGraph, i) = subgraph(g, i)[1]
+Flux.Data._getobs(g::GNNGraph, i) = getgraph(g, i)[1]
 
 #########################
 Base.:(==)(g1::GNNGraph, g2::GNNGraph) = all(k -> getfield(g1,k)==getfield(g2,k), fieldnames(typeof(g1)))
 
 """
-    subgraph(g::GNNGraph, i)
+    getgraph(g::GNNGraph, i)
 
 Return the subgraph of `g` induced by those nodes `v`
 for which `g.graph_indicator[v] ∈ i`. In other words, it
 extracts the component graphs from a batched graph.
 
 It also returns a vector `nodes` mapping the new nodes to the old ones.
 The node `i` in the subgraph corresponds to the node `nodes[i]` in `g`.
 """
-subgraph(g::GNNGraph, i::Int) = subgraph(g::GNNGraph{<:COO_T}, [i])
+getgraph(g::GNNGraph, i::Int) = getgraph(g::GNNGraph{<:COO_T}, [i])
 
-function subgraph(g::GNNGraph{<:COO_T}, i::AbstractVector{Int})
+function getgraph(g::GNNGraph{<:COO_T}, i::AbstractVector{Int})
     if g.graph_indicator === nothing
         @assert i == [1]
         return g
@@ -517,7 +514,7 @@ function edge_features(g::GNNGraph)
     end
 end
 
-function global_features(g::GNNGraph)
+function graph_features(g::GNNGraph)
     if isempty(g.gdata)
         return nothing
     elseif length(g.gdata) > 1
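
Taken together, the renames read as follows. A small sketch (not from the commit) exercising the batching and extraction API, assuming LightGraphs' `random_regular_graph` as in the test suite:

```julia
using Flux, LightGraphs, SparseArrays, GraphNeuralNetworks

g1 = GNNGraph(random_regular_graph(10, 2), ndata=rand(16, 10), gdata=rand(3, 1))
g2 = GNNGraph(random_regular_graph(4, 2), ndata=rand(16, 4), gdata=rand(3, 1))

# Flux.batch and SparseArrays.blockdiag are equivalent ways to batch graphs.
g = Flux.batch([g1, g2])
@assert g == SparseArrays.blockdiag(g1, g2)

# graph_features (formerly global_features) reads the graph-level data `gdata`.
u = graph_features(g)

# getgraph (formerly subgraph) returns a component graph plus a node map.
gsub, nodes = getgraph(g, 2)
@assert gsub.num_nodes == 4
```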

src/msgpass.jl

Lines changed: 16 additions & 29 deletions

@@ -2,7 +2,7 @@
 # "Relational inductive biases, deep learning, and graph networks"
 
 """
-    propagate(l, g, aggr, [X, E, U]) -> X′, E′, U
+    propagate(l, g, aggr, [X, E]) -> X′, E′
     propagate(l, g, aggr) -> g′
 
 Perform the sequence of operations implementing the message-passing scheme
@@ -12,11 +12,10 @@ Updates the node, edge, and global features `X`, `E`, and `U` respectively.
 The computation involved is the following:
 
 ```julia
-M = compute_batch_message(l, g, X, E, U)
+M = compute_batch_message(l, g, X, E)
 M̄ = aggregate_neighbors(l, g, aggr, M)
-X′ = update(l, X, M̄, U)
-E′ = update_edge(l, E, M, U)
-U′ = update_global(l, U, X′, E′)
+X′ = update(l, X, M̄)
+E′ = update_edge(l, E, M)
 ```
 
 Custom layers typically define their own [`update`](@ref)
@@ -26,7 +25,7 @@ this method in the forward pass:
 ```julia
 function (l::MyLayer)(g, X)
     ... some preprocessing if needed ...
-    propagate(l, g, +, X, E, U)
+    propagate(l, g, +, X, E)
 end
 ```
 
@@ -35,24 +34,21 @@ See also [`message`](@ref) and [`update`](@ref).
 function propagate end
 
 function propagate(l, g::GNNGraph, aggr)
-    X, E, U = propagate(l, g, aggr,
-                node_features(g), edge_features(g), global_features(g))
+    X, E = propagate(l, g, aggr, node_features(g), edge_features(g))
 
-    return GNNGraph(g, ndata=X, edata=E, gdata=U)
+    return GNNGraph(g, ndata=X, edata=E)
 end
 
-function propagate(l, g::GNNGraph, aggr, X, E=nothing, U=nothing)
-    # TODO consider g.graph_indicator in propagating U
-    M = compute_batch_message(l, g, X, E, U)
+function propagate(l, g::GNNGraph, aggr, X, E=nothing)
+    M = compute_batch_message(l, g, X, E)
     M̄ = aggregate_neighbors(l, g, aggr, M)
-    X′ = update(l, X, M̄, U)
-    E′ = update_edge(l, E, M, U)
-    U′ = update_global(l, U, X′, E′)
-    return X′, E′, U′
+    X′ = update(l, X, M̄)
+    E′ = update_edge(l, E, M)
+    return X′, E′
 end
 
 """
-    message(l, x_i, x_j, [e_ij, u])
+    message(l, x_i, x_j, [e_ij])
 
 Message function for the message-passing scheme,
 returning the message from node `j` to node `i`.
@@ -68,15 +64,14 @@ Custom layer should specialize this method with the desired behavior.
 - `l`: A gnn layer.
 - `x_i`: Features of the central node `i`.
 - `x_j`: Features of the neighbor `j` of node `i`.
-- `e_ij`: Features of edge (`i`, `j`).
-- `u`: Global features.
+- `e_ij`: Features of edge `(i,j)`.
 
 See also [`update`](@ref) and [`propagate`](@ref).
 """
 function message end
 
 """
-    update(l, x, m̄, [u])
+    update(l, x, m̄)
 
 Update function for the message-passing scheme,
 returning a new set of node features `x′` based on old
@@ -102,15 +97,14 @@ _gather(x::Nothing, i) = nothing
 
 ## Step 1.
 
-function compute_batch_message(l, g, X, E, U)
+function compute_batch_message(l, g, X, E)
     s, t = edge_index(g)
     Xi = _gather(X, t)
    Xj = _gather(X, s)
-    M = message(l, Xi, Xj, E, U)
+    M = message(l, Xi, Xj, E)
     return M
 end
 
-@inline message(l, x_i, x_j, e_ij, u) = message(l, x_i, x_j, e_ij)
 @inline message(l, x_i, x_j, e_ij) = message(l, x_i, x_j)
 @inline message(l, x_i, x_j) = x_j
 
@@ -125,17 +119,10 @@ aggregate_neighbors(l, g, aggr::Nothing, E) = nothing
 
 ## Step 3
 
-@inline update(l, x, m̄, u) = update(l, x, m̄)
 @inline update(l, x, m̄) = m̄
 
 ## Step 4
 
-@inline update_edge(l, E, M, U) = update_edge(l, E, M)
 @inline update_edge(l, E, M) = E
 
-## Step 5
-
-@inline update_global(l, U, X, E) = update_global(l, U, X)
-@inline update_global(l, U, X) = U
-
 ### end steps ###
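
Under the new scheme, a custom layer overrides `message` (step 1) and `update` (step 3) without the dropped `u`/`U` arguments. A hypothetical minimal layer, not part of the package, illustrating the pattern (`propagate`, `message`, and `update` are imported explicitly, since the msgpass exports are commented out):

```julia
using Flux, GraphNeuralNetworks
import GraphNeuralNetworks: message, update, propagate

# Hypothetical layer: sum neighbor features, then apply a Dense transform.
struct MySumConv{T}
    dense::T
end

MySumConv(nin::Int, nout::Int) = MySumConv(Dense(nin, nout))
Flux.@functor MySumConv

# Step 1: the message from neighbor j to node i is just j's features.
message(l::MySumConv, x_i, x_j) = x_j

# Step 3: combine the aggregated messages M̄ with the old node features.
update(l::MySumConv, x, m̄) = l.dense(m̄ .+ x)

# Forward pass: `+` aggregation; edge features default to `nothing`.
function (l::MySumConv)(g::GNNGraph, X)
    X′, _ = propagate(l, g, +, X)
    return X′
end
```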

test/gnngraph.jl

Lines changed: 5 additions & 5 deletions

@@ -120,13 +120,13 @@
         @test node_features(g123)[:,11:14] ≈ node_features(g2)
     end
 
-    @testset "subgraph" begin
+    @testset "getgraph" begin
         #TODO add graph_type=GRAPH_T
         g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10))
         g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4))
         g3 = GNNGraph(random_regular_graph(7,2), ndata=rand(16,7))
         g = Flux.batch([g1, g2, g3])
-        g2b, nodemap = subgraph(g, 2)
+        g2b, nodemap = getgraph(g, 2)
 
         s, t = edge_index(g2b)
         @test s == edge_index(g2)[1]
@@ -171,11 +171,11 @@
         g = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(10, n), edata=rand(10, m), gdata=rand(10, 1))
                         for _ in 1:num_graphs])
 
-        @test LearnBase.getobs(g, 3) == subgraph(g, 3)[1]
-        @test LearnBase.getobs(g, 3:5) == subgraph(g, 3:5)[1]
+        @test LearnBase.getobs(g, 3) == getgraph(g, 3)[1]
+        @test LearnBase.getobs(g, 3:5) == getgraph(g, 3:5)[1]
         @test LearnBase.nobs(g) == g.num_graphs
 
         d = Flux.Data.DataLoader(g, batchsize = 2, shuffle=false)
-        @test first(d) == subgraph(g, 1:2)[1]
+        @test first(d) == getgraph(g, 1:2)[1]
     end
 end
