Skip to content

Commit 46676d5

Browse files
HeteroGraphConv documentation (#312)
* batching * HeteroGraphConv docs * complete docs * more tests * fix test
1 parent d952fc5 commit 46676d5

File tree

5 files changed

+110
-6
lines changed

5 files changed

+110
-6
lines changed

docs/src/api/heterograph.md

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
1-
# GNNHeteroGraph
1+
# Hetereogeneous Graphs
22

3-
Documentation page for the graph type `GNNHeteroGraph` and related methods representing heterogeneous graphs,
4-
where nodes and edges can have different types.
3+
4+
## GNNHeteroGraph
5+
Documentation page for the type `GNNHeteroGraph` representing heterogeneous graphs, where nodes and edges can have different types.
56

67

78
```@autodocs
89
Modules = [GraphNeuralNetworks.GNNGraphs]
910
Pages = ["gnnheterograph.jl"]
1011
Private = false
1112
```
13+
14+
## Heterogeneous Graph Convolutions
15+
16+
Heterogeneous graph convolutions are implemented in the type [`HeteroGraphConv`](@ref).
17+
`HeteroGraphConv` relies on standard graph convolutional layers to perform message passing on the different relations. See the table at [this page](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/api/conv/) for the supported layers.
18+
19+
```@docs
20+
HeteroGraphConv
21+
```

docs/src/gnngraph.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,15 +189,15 @@ graph_indicator(gall)
189189
## DataLoader and mini-batch iteration
190190

191191
While constructing a batched graph and passing it to the `DataLoader` is always
192-
an option for mini-batch iteration, the recommended way is
193-
to pass an array of graphs directly:
192+
an option for mini-batch iteration, the recommended way for better performance is
193+
to pass an array of graphs directly and set the `collate` option to `true`:
194194

195195
```julia
196196
using Flux: DataLoader
197197

198198
data = [rand_graph(10, 30, ndata=rand(Float32, 3, 10)) for _ in 1:320]
199199

200-
train_loader = DataLoader(data, batchsize=16, shuffle=true)
200+
train_loader = DataLoader(data, batchsize=16, shuffle=true, collate=true)
201201

202202
for g in train_loader
203203
@assert g.num_graphs == 16

docs/src/heterograph.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,37 @@ GNNHeteroGraph:
9494
(:user, :rate, :movie) => DataStore(e = [64×4 Matrix{Float32}])
9595
```
9696

97+
## Batching
98+
Similarly to graphs, also heterographs can be batched together.
99+
```julia-repl
100+
```julia
101+
julia> gs = [rand_bipartite_heterograph((5, 10), 20) for _ in 1:32];
102+
103+
julia> Flux.batch(gs)
104+
GNNHeteroGraph:
105+
num_nodes: Dict(:A => 160, :B => 320)
106+
num_edges: Dict((:A, :to, :B) => 640, (:B, :to, :A) => 640)
107+
num_graphs: 32
108+
```
109+
Batching is automatically performed by the [`DataLoader`](@ref) iterator
110+
when the `collate` option is set to `true`.
111+
```julia-repl
112+
using Flux: DataLoader
113+
114+
data = [rand_bipartite_heterograph((5, 10), 20,
115+
ndata=Dict(:A=>rand(Float32, 3, 5)))
116+
for _ in 1:320];
117+
118+
train_loader = DataLoader(data, batchsize=16, shuffle=true, collate=true)
119+
120+
for g in train_loader
121+
@assert g.num_graphs == 16
122+
@assert g.num_nodes[:A] == 80
123+
@assert size(g.ndata[:A].x) == (3, 80)
124+
# ...
125+
end
126+
```
127+
128+
## Graph convolutions on heterographs
97129

130+
See [`HeteroGraphConv`](@ref) for how to perform convolutions on heterogenous graphs.

src/layers/heteroconv.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,41 @@
1+
@doc raw"""
2+
HeteroGraphConv(itr; aggr = +)
3+
HeteroGraphConv(pairs...; aggr = +)
4+
5+
A convolutional layer for heterogeneous graphs.
6+
7+
The `itr` argument is an iterator of `pairs` of the form `edge_t => layer`, where `edge_t` is a
8+
3-tuple of the form `(src_node_type, edge_type, dst_node_type)`, and `layer` is a
9+
convolutional layers for homogeneous graphs.
10+
11+
Each convolution is applied to the corresponding relation.
12+
Since a node type can be involved in multiple relations, the single convolution outputs
13+
have to be aggregated using the `aggr` function. The default is to sum the outputs.
14+
15+
# Forward Arguments
16+
17+
* `g::GNNHeteroGraph`: The input graph.
18+
* `x::Union{NamedTuple,Dict}`: The input node features. The keys are node types and the
19+
values are node feature tensors.
20+
21+
# Examples
22+
23+
```julia-repl
24+
julia> g = rand_bipartite_heterograph((10, 15), 20)
25+
GNNHeteroGraph:
26+
num_nodes: Dict(:A => 10, :B => 15)
27+
num_edges: Dict((:A, :to, :B) => 20, (:B, :to, :A) => 20)
28+
29+
julia> x = (A = rand(Float32, 64, 10), B = rand(Float32, 64, 15));
30+
31+
julia> layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, relu),
32+
(:B, :to, :A) => GraphConv(64 => 32, relu));
33+
34+
julia> y = layer(g, x); # output is a named tuple
35+
36+
julia> size(y.A) == (32, 10) && size(y.B) == (32, 15)
37+
true
38+
"""
139
struct HeteroGraphConv
240
etypes::Vector{EType}
341
layers::Vector{<:GNNLayer}
@@ -6,6 +44,9 @@ end
644

745
Flux.@functor HeteroGraphConv
846

47+
HeteroGraphConv(itr::Dict; aggr = +) = HeteroGraphConv(pairs(itr); aggr)
48+
HeteroGraphConv(itr::Pair...; aggr = +) = HeteroGraphConv(itr; aggr)
49+
950
function HeteroGraphConv(itr; aggr = +)
1051
etypes = [k[1] for k in itr]
1152
layers = [k[2] for k in itr]
@@ -37,3 +78,17 @@ function _reduceby_node_t(aggr, outs, ntypes)
3778
vals = [_reduce(node_t) for node_t in ntypes]
3879
return NamedTuple{tuple(ntypes...)}(vals)
3980
end
81+
82+
function Base.show(io::IO, hgc::HeteroGraphConv)
83+
if get(io, :compact, false)
84+
print(io, "HeteroGraphConv(aggr=$(hgc.aggr))")
85+
else
86+
println(io, "HeteroGraphConv(aggr=$(hgc.aggr)):")
87+
for (i, (et,layer)) in enumerate(zip(hgc.etypes, hgc.layers))
88+
print(io, " $(et => layer)")
89+
if i < length(hgc.etypes)
90+
print(io, "\n")
91+
end
92+
end
93+
end
94+
end

test/layers/heteroconv.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,10 @@
2727
@test dx[:A] ndx[:A] rtol=1e-4
2828
@test dx[:B] ndx[:B] rtol=1e-4
2929
end
30+
31+
@testset "Constructor from pairs" begin
32+
layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, relu),
33+
(:B, :to, :A) => GraphConv(64 => 32, relu));
34+
@test length(layer.etypes) == 2
35+
end
3036
end

0 commit comments

Comments
 (0)