Skip to content

Commit 7aa229b

Browse files
extend reduce_nodes and graph_indicator for heterographs (#339)
* graph indicator and reduce nodes * more tests
1 parent 2f2be1a commit 7aa229b

File tree

6 files changed

+70
-12
lines changed

6 files changed

+70
-12
lines changed

docs/src/heterograph.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
Heterogeneous graphs (also called heterographs), are graphs where each node has a type,
44
that we denote with symbols such as `:user` and `:movie`.
5-
Also edges have a type, such as `:rate` or `:like`, and they can connect nodes of different types. We call a triplet `(source_node_type, edge_type, target_node_type)` the type of a *relation*, e.g. `(:user, :rate, :movie)`.
5+
Releations such as `:rate` or `:like` can connect nodes of different types. We call a triplet `(source_node_type, relation_type, target_node_type)` the type of a edge, e.g. `(:user, :rate, :movie)`.
66

77
Different node/edge types can store different groups of features
88
and this makes heterographs a very flexible modeling tools
@@ -12,7 +12,7 @@ the type [`GNNHeteroGraph`](@ref).
1212

1313
## Creating a Heterograph
1414

15-
A heterograph can be created by passing pairs of relation type and data to the constructor.
15+
A heterograph can be created by passing pairs `edge_type => data` to the constructor.
1616
```jldoctest
1717
julia> g = GNNHeteroGraph((:user, :like, :actor) => ([1,2,2,3], [1,3,2,9]),
1818
(:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7]))

src/GNNGraphs/query.jl

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ edge_index(g::GNNGraph{<:COO_T}) = g.graph[1:2]
1313

1414
edge_index(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes = g.num_nodes)[1][1:2]
1515

16-
""""
16+
"""
1717
edge_index(g::GNNHeteroGraph, [edge_t])
1818
1919
Return a tuple containing two vectors, respectively storing the source and target nodes
20-
for each edges in `g` of type `edge_t = (:node1_t, :rel, :node2_t)`.
20+
for each edges in `g` of type `edge_t = (src_t, rel_t, trg_t)`.
2121
2222
If `edge_t` is not provided, it will error if `g` has more than one edge type.
2323
"""
@@ -454,12 +454,13 @@ _rand_dense_vector(A::AbstractMatrix{T}) where {T} = randn(float(T), size(A, 1))
454454
# https://discourse.julialang.org/t/cuda-eigenvalues-of-a-sparse-matrix/46851/5
455455

456456
"""
457-
graph_indicator(g)
457+
graph_indicator(g::GNNGraph; edges=false)
458458
459459
Return a vector containing the graph membership
460460
(an integer from `1` to `g.num_graphs`) of each node in the graph.
461+
If `edges=true`, return the graph membership of each edge instead.
461462
"""
462-
function graph_indicator(g; edges = false)
463+
function graph_indicator(g::GNNGraph; edges = false)
463464
if isnothing(g.graph_indicator)
464465
gi = ones_like(edge_index(g)[1], Int, g.num_nodes)
465466
else
@@ -473,6 +474,29 @@ function graph_indicator(g; edges = false)
473474
end
474475
end
475476

477+
"""
478+
graph_indicator(g::GNNHeteroGraph, [node_t])
479+
480+
Return a Dict of vectors containing the graph membership
481+
(an integer from `1` to `g.num_graphs`) of each node in the graph for each node type.
482+
If `node_t` is provided, return the graph membership of each node of type `node_t` instead.
483+
484+
See also [`batch`](@ref).
485+
"""
486+
function graph_indicator(g::GNNHeteroGraph)
487+
return g.graph_indicator
488+
end
489+
490+
function graph_indicator(g::GNNHeteroGraph, node_t::Symbol)
491+
@assert node_t g.ntypes
492+
if isnothing(g.graph_indicator)
493+
gi = ones_like(edge_index(g, first(g.etypes))[1], Int, g.num_nodes[node_t])
494+
else
495+
gi = g.graph_indicator[node_t]
496+
end
497+
return gi
498+
end
499+
476500
function node_features(g::GNNGraph)
477501
if isempty(g.ndata)
478502
return nothing

src/GNNGraphs/transform.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,15 +193,16 @@ function add_edges(g::GNNGraph{<:COO_T}, data::COO_T; edata = nothing)
193193
end
194194

195195
"""
196-
add_edges(g::GNNHeteroGraph, rel_t, s, t; [edata, num_nodes])
197-
add_edges(g::GNNHeteroGraph, rel_t => (s, t); [edata, num_nodes])
198-
add_edges(g::GNNHeteroGraph, rel_t => (s, t, w); [edata, num_nodes])
196+
add_edges(g::GNNHeteroGraph, edge_t, s, t; [edata, num_nodes])
197+
add_edges(g::GNNHeteroGraph, edge_t => (s, t); [edata, num_nodes])
198+
add_edges(g::GNNHeteroGraph, edge_t => (s, t, w); [edata, num_nodes])
199199
200-
Add to heterograph `g` the relation of type `rel_t` with source node vector `s` and target node vector `t`.
200+
Add to heterograph `g` edges of type `edge_t` with source node vector `s` and target node vector `t`.
201201
Optionally, pass the edge weights `w` or the features `edata` for the new edges.
202-
`rel_t` is a triplet of symbols `(src_t, edge_t, dst_t)`.
202+
`edge_t` is a triplet of symbols `(src_t, rel_t, dst_t)`.
203203
204-
If the relation is not already present in the graph, it is added. If it involves new node types, they are added to the graph as well.
204+
If the edge type is not already present in the graph, it is added.
205+
If it involves new node types, they are added to the graph as well.
205206
In this case, a dictionary or named tuple of `num_nodes` can be passed to specify the number of nodes of the new types,
206207
otherwise the number of nodes is inferred from the maximum node id in `s` and `t`.
207208
"""

src/utils.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,27 @@ ofeltype(x, y) = convert(float(eltype(x)), y)
66
For a batched graph `g`, return the graph-wise aggregation of the node
77
features `x`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`.
88
The returned array will have last dimension `g.num_graphs`.
9+
10+
See also: [`reduce_edges`](@ref).
911
"""
1012
function reduce_nodes(aggr, g::GNNGraph, x)
1113
@assert size(x)[end] == g.num_nodes
1214
indexes = graph_indicator(g)
1315
return NNlib.scatter(aggr, x, indexes)
1416
end
1517

18+
"""
19+
reduce_nodes(aggr, indicator::AbstractVector, x)
20+
21+
Return the graph-wise aggregation of the node features `x` given the
22+
graph indicator `indicator`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`.
23+
24+
See also [`graph_indicator`](@ref).
25+
"""
26+
function reduce_nodes(aggr, indicator::AbstractVector, x)
27+
return NNlib.scatter(aggr, x, indicator)
28+
end
29+
1630
"""
1731
reduce_edges(aggr, g, e)
1832

test/GNNGraphs/query.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,19 @@ end
232232
@test eltype(khop_adj(g, 10, Float32)) == Float32
233233
end
234234
end
235+
236+
if GRAPH_T == :coo
237+
@testset "HeteroGraph" begin
238+
@testset "graph_indicator" begin
239+
gs = [rand_heterograph(Dict(:user => 10, :movie => 20, :actor => 30),
240+
Dict((:user,:like,:movie) => 10,
241+
(:actor,:rate,:movie)=>20)) for _ in 1:3]
242+
g = MLUtils.batch(gs)
243+
@test graph_indicator(g) == Dict(:user => [repeat([1], 10); repeat([2], 10); repeat([3], 10)],
244+
:movie => [repeat([1], 20); repeat([2], 20); repeat([3], 20)],
245+
:actor => [repeat([1], 30); repeat([2], 30); repeat([3], 30)])
246+
@test graph_indicator(g, :movie) == [repeat([1], 20); repeat([2], 20); repeat([3], 20)]
247+
end
248+
end
249+
end
250+

test/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ e = g.edata.e
1010
r = reduce_nodes(mean, g, x)
1111
@test size(r) == (Dx, g.num_graphs)
1212
@test r[:, 2] mean(getgraph(g, 2).ndata.x, dims = 2)
13+
14+
r2 = reduce_nodes(mean, graph_indicator(g), x)
15+
@test r2 == r
1316
end
1417

1518
@testset "reduce_edges" begin

0 commit comments

Comments
 (0)