Skip to content

Commit 32ca15a

Browse files
feat: Add empty constructor for GNNHeteroGraph (#358)
* add empty heterograph constructor * update docs * Update src/GNNGraphs/convert.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update docs/src/heterograph.md Co-authored-by: Carlo Lucibello <[email protected]> * add tests --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 2c11b95 commit 32ca15a

File tree

4 files changed

+49
-23
lines changed

4 files changed

+49
-23
lines changed

docs/src/heterograph.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,13 @@ the type [`GNNHeteroGraph`](@ref).
1212

1313
## Creating a Heterograph
1414

15-
A heterograph can be created by passing pairs `edge_type => data` to the constructor.
15+
A heterograph can be created empty or by passing pairs `edge_type => data` to the constructor.
1616
```jldoctest
17+
julia> g = GNNHeteroGraph()
18+
GNNHeteroGraph:
19+
num_nodes: Dict()
20+
num_edges: Dict()
21+
1722
julia> g = GNNHeteroGraph((:user, :like, :actor) => ([1,2,2,3], [1,3,2,9]),
1823
(:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7]))
1924
GNNHeteroGraph:

src/GNNGraphs/convert.jl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,28 @@ function to_coo(data::EDict; num_nodes = nothing, kws...)
44
graph = EDict{COO_T}()
55
_num_nodes = NDict{Int}()
66
num_edges = EDict{Int}()
7-
for k in keys(data)
8-
d = data[k]
9-
@assert d isa Tuple
10-
if length(d) == 2
11-
d = (d..., nothing)
7+
if !isempty(data)
8+
for k in keys(data)
9+
d = data[k]
10+
@assert d isa Tuple
11+
if length(d) == 2
12+
d = (d..., nothing)
13+
end
14+
if num_nodes !== nothing
15+
n1 = get(num_nodes, k[1], nothing)
16+
n2 = get(num_nodes, k[3], nothing)
17+
else
18+
n1 = nothing
19+
n2 = nothing
20+
end
21+
g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...)
22+
graph[k] = g
23+
num_edges[k] = nedges
24+
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
25+
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
1226
end
13-
if num_nodes !== nothing
14-
n1 = get(num_nodes, k[1], nothing)
15-
n2 = get(num_nodes, k[3], nothing)
16-
else
17-
n1 = nothing
18-
n2 = nothing
19-
end
20-
g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...)
21-
graph[k] = g
22-
num_edges[k] = nedges
23-
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
24-
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
27+
graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types
2528
end
26-
graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types
2729
return graph, _num_nodes, num_edges
2830
end
2931

src/GNNGraphs/gnnheterograph.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ end
100100
GNNHeteroGraph(data; kws...) = GNNHeteroGraph(Dict(data); kws...)
101101
GNNHeteroGraph(data::Pair...; kws...) = GNNHeteroGraph(Dict(data...); kws...)
102102

103+
GNNHeteroGraph() = GNNHeteroGraph(Dict{Tuple{Symbol,Symbol,Symbol}, Any}())
104+
103105
function GNNHeteroGraph(data::Dict; kws...)
104106
all(k -> k isa EType, keys(data)) || throw(ArgumentError("Keys of data must be tuples of the form `(source_type, edge_type, target_type)`"))
105107
return GNNHeteroGraph(Dict([k => v for (k, v) in pairs(data)]...); kws...)
@@ -135,10 +137,17 @@ function GNNHeteroGraph(data::EDict;
135137
num_graphs = !isnothing(graph_indicator) ?
136138
maximum([maximum(gi) for gi in values(graph_indicator)]) : 1
137139

138-
ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes)
139-
edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges,
140-
duplicate_if_needed = true)
141-
gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs)
140+
141+
if length(keys(graph)) == 0
142+
ndata = Dict{Symbol, DataStore}()
143+
edata = Dict{Tuple{Symbol, Symbol, Symbol}, DataStore}()
144+
gdata = DataStore()
145+
else
146+
ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes)
147+
edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges,
148+
duplicate_if_needed = true)
149+
gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs)
150+
end
142151

143152
return GNNHeteroGraph(graph,
144153
num_nodes, num_edges, num_graphs,

test/GNNGraphs/gnnheterograph.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11

2+
3+
@testset "Empty constructor" begin
4+
g = GNNHeteroGraph()
5+
@test isempty(g.num_nodes)
6+
g = add_edges(g, (:user, :like, :actor) => ([1,2,3,3,3], [3,5,1,9,4]))
7+
@test g.num_nodes[:user] == 3
8+
@test g.num_nodes[:actor] == 9
9+
@test g.num_edges[(:user, :like, :actor)] == 5
10+
end
11+
212
@testset "Constructor from pairs" begin
313
hg = GNNHeteroGraph((:A, :e1, :B) => ([1,2,3,4], [3,2,1,5]))
414
@test hg.num_nodes == Dict(:A => 4, :B => 5)

0 commit comments

Comments
 (0)