Skip to content

Commit 692d2b2

Browse files
HeteroGraphConv implementation (#300)
* [wip] heteroconv implementation * now working * lower gattest precision * fix tests * fix tests
1 parent 12fee79 commit 692d2b2

21 files changed

+260
-68
lines changed

docs/src/api/conv.md

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,30 @@ Some of the most commonly used layers are the [`GCNConv`](@ref) and the [`GATv2C
1010

1111
The table below lists all graph convolutional layers implemented in the *GraphNeuralNetworks.jl*. It also highlights the presence of some additional capabilities with respect to basic message passing:
1212
- *Sparse Ops*: implements message passing as multiplication by sparse adjacency matrix instead of the gather/scatter mechanism. This can lead to better cpu performances but it is not supported on gpu yet.
13-
- *Edge Weights*: supports scalar weights (or equivalently scalar features) on edges.
13+
- *Edge Weight*: supports scalar weights (or equivalently scalar features) on edges.
1414
- *Edge Features*: supports feature vectors on edges.
15+
- *Heterograph*: supports heterogeneous graphs (see [`GNNHeteroGraphs`](@ref)).
1516

16-
| Layer |Sparse Ops|Edge Weight|Edge Features|
17-
| :-------- | :---: |:---: |:---: |
18-
| [`AGNNConv`](@ref) | | ||
19-
| [`CGConv`](@ref) | | ||
20-
| [`ChebConv`](@ref) | | | |
21-
| [`EGNNConv`](@ref) | | ||
22-
| [`EdgeConv`](@ref) | | | |
23-
| [`GATConv`](@ref) | | ||
24-
| [`GATv2Conv`](@ref) | | ||
25-
| [`GatedGraphConv`](@ref) || | |
26-
| [`GCNConv`](@ref) ||| |
27-
| [`GINConv`](@ref) || | |
28-
| [`GMMConv`](@ref) | | ||
29-
| [`GraphConv`](@ref) || | |
30-
| [`MEGNetConv`](@ref) | | ||
31-
| [`NNConv`](@ref) | | ||
32-
| [`ResGatedGraphConv`](@ref) | | | |
33-
| [`SAGEConv`](@ref) || | |
34-
| [`SGConv`](@ref) || | |
35-
| [`TransformerConv`](@ref) | | ||
17+
| Layer |Sparse Ops|Edge Weight|Edge Features| Heterograph |
18+
| :-------- | :---: |:---: |:---: | :---: |
19+
| [`AGNNConv`](@ref) | | || |
20+
| [`CGConv`](@ref) | | || |
21+
| [`ChebConv`](@ref) | | | | |
22+
| [`EGNNConv`](@ref) | | || |
23+
| [`EdgeConv`](@ref) | | | | |
24+
| [`GATConv`](@ref) | | || |
25+
| [`GATv2Conv`](@ref) | | || |
26+
| [`GatedGraphConv`](@ref) || | | |
27+
| [`GCNConv`](@ref) ||| | |
28+
| [`GINConv`](@ref) || | | |
29+
| [`GMMConv`](@ref) | | || |
30+
| [`GraphConv`](@ref) || | ||
31+
| [`MEGNetConv`](@ref) | | || |
32+
| [`NNConv`](@ref) | | || |
33+
| [`ResGatedGraphConv`](@ref) | | | | |
34+
| [`SAGEConv`](@ref) || | | |
35+
| [`SGConv`](@ref) || | | |
36+
| [`TransformerConv`](@ref) | | || |
3637

3738

3839
## Docs

docs/src/api/heterograph.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# HeteroGNNGraph
1+
# GNNHeteroGraph
22

3-
Documentation page for the graph type `HeteroGNNGraph` and related methods representing heterogeneous graphs,
3+
Documentation page for the graph type `GNNHeteroGraph` and related methods representing heterogeneous graphs,
44
where nodes and edges can have different types.
55

66

perf/neural_ode_mnist.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,9 @@ model = Chain(Flux.flatten,
3535
Dense(nhidden, nout)) |> device
3636

3737
# # Training
38-
# ## Model Parameters
39-
ps = Flux.params(model);
4038

4139
# ## Optimizer
42-
opt = Adam(0.01)
40+
opt = Flux.setup(Adam(0.01), model)
4341

4442
function eval_loss_accuracy(X, y)
4543
= model(X)
@@ -50,10 +48,10 @@ end
5048

5149
# ## Training Loop
5250
for epoch in 1:epochs
53-
gs = gradient(ps) do
51+
grad = gradient(model) do model
5452
= model(X)
5553
logitcrossentropy(ŷ, y)
5654
end
57-
Flux.Optimise.update!(opt, ps, gs)
55+
Flux.update!(opt, model, grad[1])
5856
@show eval_loss_accuracy(X, y)
5957
end

perf/node_classification_cora_geometricflux.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ function train(; kws...)
5858
GCNConv(g, nhidden => nhidden, relu),
5959
Dense(nhidden, nout)) |> device
6060

61-
ps = Flux.params(model)
62-
opt = Adam(args.η)
61+
opt = Flux.setup(Adam(args.η), model)
6362

6463
@info g
6564

@@ -73,12 +72,12 @@ function train(; kws...)
7372
## TRAINING
7473
report(0)
7574
for epoch in 1:(args.epochs)
76-
gs = Flux.gradient(ps) do
75+
grad = Flux.gradient(model) do model
7776
= model(X)
7877
logitcrossentropy(ŷ[:, train_ids], ytrain)
7978
end
8079

81-
Flux.Optimise.update!(opt, ps, gs)
80+
Flux.update!(opt, model, grad[1])
8281

8382
epoch % args.infotime == 0 && report(epoch)
8483
end

src/GNNGraphs/GNNGraphs.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ import Functors
2121
include("datastore.jl")
2222
export DataStore
2323

24+
include("abstracttypes.jl")
25+
export AbstractGNNGraph
26+
2427
include("gnngraph.jl")
2528
export GNNGraph,
2629
node_features,
@@ -30,7 +33,8 @@ export GNNGraph,
3033
include("gnnheterograph.jl")
3134
export GNNHeteroGraph,
3235
num_edge_types,
33-
num_node_types
36+
num_node_types,
37+
edge_type_subgraph
3438

3539
include("temporalsnapshotsgnngraph.jl")
3640
export TemporalSnapshotsGNNGraph,

src/GNNGraphs/abstracttypes.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
2+
const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}
3+
const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}}
4+
const ADJMAT_T = AbstractMatrix
5+
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
6+
const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
7+
8+
const AVecI = AbstractVector{<:Integer}
9+
10+
# All concrete graph types should be subtypes of AbstractGNNGraph{T}.
11+
# GNNGraph and GNNHeteroGraph are the two concrete types.
12+
abstract type AbstractGNNGraph{T} <: AbstractGraph{Int} end

src/GNNGraphs/gnngraph.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
11
#===================================
2-
Define GNNGraph type as a subtype of Graphs' AbstractGraph.
2+
Define GNNGraph type as a subtype of Graphs.AbstractGraph.
33
For the core methods to be implemented by any AbstractGraph, see
44
https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type
55
https://juliagraphs.org/Graphs.jl/latest/developing/#Developing-Alternate-Graph-Types
66
=============================================#
77

8-
const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}
9-
const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}}
10-
const ADJMAT_T = AbstractMatrix
11-
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
12-
const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
13-
14-
const AVecI = AbstractVector{<:Integer}
15-
168
"""
179
GNNGraph(data; [graph_type, ndata, edata, gdata, num_nodes, graph_indicator, dir])
1810
GNNGraph(g::GNNGraph; [ndata, edata, gdata])
@@ -113,7 +105,7 @@ g = g |> gpu
113105
source, target = edge_index(g)
114106
```
115107
"""
116-
struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGraph{Int}
108+
struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
117109
graph::T
118110
num_nodes::Int
119111
num_edges::Int

src/GNNGraphs/gnnheterograph.jl

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ julia> hg.ndata[:A].x
8080
8181
See also [`GNNGraph`](@ref) for a homogeneous graph type and [`rand_heterograph`](@ref) for a function to generate random heterographs.
8282
"""
83-
struct GNNHeteroGraph{T <: Union{COO_T, ADJMAT_T}}
83+
struct GNNHeteroGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
8484
graph::EDict{T}
8585
num_nodes::NDict{Int}
8686
num_edges::EDict{Int}
@@ -225,3 +225,67 @@ For [`GNNHeteroGraph`](@ref)s, this is the number of unique node types.
225225
num_node_types(g::GNNGraph) = 1
226226

227227
num_node_types(g::GNNHeteroGraph) = length(g.ntypes)
228+
229+
230+
"""
231+
edge_type_subgraph(g::GNNHeteroGraph, edge_ts)
232+
233+
Return a subgraph of `g` that contains only the edges of type `edge_ts`.
234+
Edge types can be specified as a single edge type (i.e. a tuple containing 3 symbols) or a vector of edge types.
235+
"""
236+
edge_type_subgraph(g::GNNHeteroGraph, edge_t::EType) = edge_type_subgraph(g, [edge_t])
237+
238+
function edge_type_subgraph(g::GNNHeteroGraph, edge_ts::AbstractVector{<:EType})
239+
for edge_t in edge_ts
240+
@assert edge_t in g.etypes "Edge type $(edge_t) not found in graph"
241+
end
242+
node_ts = _ntypes_from_edges(edge_ts)
243+
graph = Dict(edge_t => g.graph[edge_t] for edge_t in edge_ts)
244+
num_nodes = Dict(node_t => g.num_nodes[node_t] for node_t in node_ts)
245+
num_edges = Dict(edge_t => g.num_edges[edge_t] for edge_t in edge_ts)
246+
if g.graph_indicator === nothing
247+
graph_indicator = nothing
248+
else
249+
graph_indicator = Dict(node_t => g.graph_indicator[node_t] for node_t in node_ts)
250+
end
251+
ndata = Dict(node_t => g.ndata[node_t] for node_t in node_ts if node_t in keys(g.ndata))
252+
edata = Dict(edge_t => g.edata[edge_t] for edge_t in edge_ts if edge_t in keys(g.edata))
253+
254+
return GNNHeteroGraph(graph, num_nodes, num_edges, g.num_graphs,
255+
graph_indicator, ndata, edata, g.gdata,
256+
node_ts, edge_ts)
257+
end
258+
259+
# TODO this is not correct but Zygote cannot differentiate
260+
# through dictionary generation
261+
@non_differentiable edge_type_subgraph(::Any...)
262+
263+
function _ntypes_from_edges(edge_ts::AbstractVector{<:EType})
264+
ntypes = Symbol[]
265+
for edge_t in edge_ts
266+
node1_t, _, node2_t = edge_t
267+
!in(node1_t, ntypes) && push!(ntypes, node1_t)
268+
!in(node2_t, ntypes) && push!(ntypes, node2_t)
269+
end
270+
return ntypes
271+
end
272+
273+
@non_differentiable _ntypes_from_edges(::Any...)
274+
275+
276+
function Base.getindex(g::GNNHeteroGraph, node_t::NType)
277+
if !haskey(g.ndata, node_t) && node_t in g.ntypes
278+
g.ndata[node_t] = DataStore(g.num_nodes[node_t])
279+
end
280+
return g.ndata[node_t]
281+
end
282+
283+
Base.setindex!(g::GNNHeteroGraph, node_t::NType, x) = g.ndata[node_t] = x
284+
285+
function Base.getindex(g::GNNHeteroGraph, edge_t::EType)
286+
if !haskey(g.edata, node_t) && edge_t in g.etypes
287+
g.ndata[node_t] = DataStore(g.num_edges[edge_t])
288+
end
289+
return g.edata[edge_t]
290+
end
291+
Base.setindex!(g::GNNHeteroGraph, edge_t::EType, x) = g.edata[edge_t] = x

src/GNNGraphs/query.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@ edge_index(g::GNNGraph{<:COO_T}) = g.graph[1:2]
1414
edge_index(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes = g.num_nodes)[1][1:2]
1515

1616
""""
17-
edge_index(g::GNNHeteroGraph, edge_t)
17+
edge_index(g::GNNHeteroGraph, [edge_t])
1818
19-
Return a tuple containing two vectors, respectively storing the source and target nodes for each edges in `g` of type `edge_t = (:node1_t, :rel, :node2_t)`.
19+
Return a tuple containing two vectors, respectively storing the source and target nodes
20+
for each edges in `g` of type `edge_t = (:node1_t, :rel, :node2_t)`.
21+
22+
If `edge_t` is not provided, it will error if `g` has more than one edge type.
2023
"""
2124
edge_index(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][1:2]
25+
edge_index(g::GNNHeteroGraph{<:COO_T}) = only(g.graph)[2][1:2]
2226

2327
get_edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3]
2428

src/GNNGraphs/transform.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,3 +772,4 @@ ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci
772772
@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
773773
@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
774774
@non_differentiable dense_zeros_like(x...)
775+

0 commit comments

Comments
 (0)