Skip to content

Commit e6d8d48

Browse files
better inference for num_nodes
1 parent 2883214 commit e6d8d48

File tree

4 files changed

+37
-9
lines changed

4 files changed

+37
-9
lines changed

GNNGraphs/src/gnngraph.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ function GNNGraph(data::D;
149149

150150
# don't force the shape of the data when there is only one graph
151151
gdata = normalize_graphdata(gdata, default_name = :u,
152-
n = num_graphs > 1 ? num_graphs : -1)
152+
n = num_graphs > 1 ? num_graphs : -1, glob=true)
153153

154154
GNNGraph(graph,
155155
num_nodes, num_edges, num_graphs,
@@ -201,10 +201,10 @@ end
201201

202202
function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata,
203203
graph_type = nothing)
204-
ndata = normalize_graphdata(ndata, default_name = :x, n = g.num_nodes)
205-
edata = normalize_graphdata(edata, default_name = :e, n = g.num_edges,
206-
duplicate_if_needed = true)
207-
gdata = normalize_graphdata(gdata, default_name = :u, n = g.num_graphs)
204+
ndata = normalize_graphdata(ndata, default_name=:x, n=g.num_nodes)
205+
edata = normalize_graphdata(edata, default_name=:e, n=g.num_edges,
206+
duplicate_if_needed=true)
207+
gdata = normalize_graphdata(gdata, default_name=:u, n=g.num_graphs, glob=true)
208208

209209
if !isnothing(graph_type)
210210
if graph_type == :coo

GNNGraphs/src/gnnheterograph/gnnheterograph.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ function GNNHeteroGraph(data::EDict;
145145
edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges,
146146
duplicate_if_needed = true)
147147
gdata = normalize_graphdata(gdata, default_name = :u,
148-
n = num_graphs > 1 ? num_graphs : -1)
148+
n = num_graphs > 1 ? num_graphs : -1, glob=true)
149149
end
150150

151151
return GNNHeteroGraph(graph,

GNNGraphs/src/utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,16 +129,16 @@ function normalize_graphdata(data; default_name::Symbol, kws...)
129129
normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...)
130130
end
131131

132-
function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false)
132+
function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false, glob=false)
133133
# This had to workaround two Zygote bugs with NamedTuples
134134
# https://github.com/FluxML/Zygote.jl/issues/1071
135-
# https://github.com/FluxML/Zygote.jl/issues/1072
135+
# https://github.com/FluxML/Zygote.jl/issues/1072 # TODO this is fixed
136136

137137
if n > 1
138138
@assert all(x -> x isa AbstractArray, data) "Non-array features provided."
139139
end
140140

141-
if n <= 1
141+
if n <= 1 && glob
142142
# If last array dimension is not 1, add a new dimension.
143143
# This is mostly useful to reshape global feature vectors
144144
# of size D to Dx1 matrices.

GNNGraphs/test/gnngraph.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,34 @@ end
5353
end
5454
end
5555

56+
@testitem "Constructor: empty" setup=[GraphsTestModule] begin
57+
g = GNNGraph(ndata=ones(2, 1))
58+
@test g.num_nodes == 1
59+
@test g.num_edges == 0
60+
@test g.ndata.x == ones(2, 1)
61+
62+
g = GNNGraph(num_nodes=1)
63+
@test g.num_nodes == 1
64+
@test g.num_edges == 0
65+
@test isempty(g.ndata)
66+
67+
g = GNNGraph((Int[], Int[]); ndata=(;a=[1]))
68+
@test g.num_nodes == 1
69+
@test g.num_edges == 0
70+
@test g.ndata.a == [1]
71+
72+
g = GNNGraph((Int[], Int[]); ndata=(;a=[1]), edata=(;b=Int[]), num_nodes=1)
73+
@test g.num_nodes == 1
74+
@test g.num_edges == 0
75+
@test g.ndata.a == [1]
76+
@test g.edata.b == Int[]
77+
78+
g = GNNGraph(; edata=(;b=Int[]))
79+
@test g.num_nodes == 0
80+
@test g.num_edges == 0
81+
@test g.edata.b == Int[]
82+
end
83+
5684
@testitem "symmetric graph" setup=[GraphsTestModule] tags=[:gpu] begin
5785
using .GraphsTestModule
5886
dev = gpu_device()

0 commit comments

Comments
 (0)