Skip to content

Commit 04d026b

Browse files
Merge pull request #145 from CarloLucibello/cl/nodecheck
better printing for GNNGraphs and check that features are arrays
2 parents a82a320 + 9ba9f5f commit 04d026b

File tree

3 files changed

+39
-23
lines changed

3 files changed

+39
-23
lines changed

src/GNNGraphs/gnngraph.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,19 +206,19 @@ function Base.show(io::IO, g::GNNGraph)
206206
if !isempty(g.ndata)
207207
print(io, "\n ndata:")
208208
for k in keys(g.ndata)
209-
print(io, "\n $k => $(size(g.ndata[k]))")
209+
print(io, "\n $k => $(summary(g.ndata[k]))")
210210
end
211211
end
212212
if !isempty(g.edata)
213213
print(io, "\n edata:")
214214
for k in keys(g.edata)
215-
print(io, "\n $k => $(size(g.edata[k]))")
215+
print(io, "\n $k => $(summary(g.edata[k]))")
216216
end
217217
end
218218
if !isempty(g.gdata)
219219
print(io, "\n gdata:")
220220
for k in keys(g.gdata)
221-
print(io, "\n $k => $(size(g.gdata[k]))")
221+
print(io, "\n $k => $(summary(g.gdata[k]))")
222222
end
223223
end
224224
end

src/GNNGraphs/utils.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,36 +57,38 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
5757
# This had to workaround two Zygote bugs with NamedTuples
5858
# https://github.com/FluxML/Zygote.jl/issues/1071
5959
# https://github.com/FluxML/Zygote.jl/issues/1072
60+
61+
if n != 1
62+
@assert all(x -> x isa AbstractArray, data) "Non-array features provided."
63+
end
6064

6165
if n == 1
6266
# If last array dimension is not 1, add a new dimension.
63-
# This is mostly usefule to reshape globale feature vectors
67+
# This is mostly useful to reshape global feature vectors
6468
# of size D to Dx1 matrices.
65-
function unsqz(v)
66-
if v isa AbstractArray && size(v)[end] != 1
67-
v = reshape(v, size(v)..., 1)
68-
end
69-
v
70-
end
69+
unsqz_last(v::AbstractArray) = size(v)[end] != 1 ? reshape(v, size(v)..., 1) : v
70+
unsqz_last(v) = v
7171

72-
data = NamedTuple{keys(data)}(unsqz.(values(data)))
72+
data = map(unsqz_last, data)
7373
end
74-
75-
sz = map(x -> x isa AbstractArray ? size(x)[end] : 0, data)
76-
if duplicate_if_needed
77-
# Used to copy edge features on reverse edges
78-
@assert all(s -> s == 0 || s == n || s == n÷2, sz) "Wrong size in last dimension for feature array."
7974

75+
## Turn vectors in 1 x n matrices.
76+
# unsqz_first(v::AbstractVector) = reshape(v, 1, length(v))
77+
# unsqz_first(v) = v
78+
# data = map(unsqz_first, data)
79+
80+
if duplicate_if_needed
8081
function duplicate(v)
8182
if v isa AbstractArray && size(v)[end] == n÷2
8283
v = cat(v, v, dims=ndims(v))
8384
end
8485
v
8586
end
86-
data = NamedTuple{keys(data)}(duplicate.(values(data)))
87-
else
88-
@assert all(s -> s == 0 || s == n, sz) "Wrong size in last dimension for feature array."
87+
data = map(duplicate, data)
8988
end
89+
90+
@assert all(x -> x isa AbstractArray ? size(x)[end] == n : true, data) "Wrong size in last dimension for feature array."
91+
9092
return data
9193
end
9294

test/GNNGraphs/gnngraph.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,24 @@
225225
g = GNNGraph(erdos_renyi(10, 30), edata=e, graph_type=GRAPH_T)
226226
@test g.edata.e == [e; e]
227227

228-
229-
# Attach non array data
230-
g = GNNGraph(erdos_renyi(10, 30), edata="ciao", graph_type=GRAPH_T)
231-
@test g.edata.e == "ciao"
228+
# non-array global
229+
g = rand_graph(10, 30, gdata="ciao", graph_type=GRAPH_T)
230+
@test g.gdata.u == "ciao"
231+
232+
# vectors stays vectors
233+
g = rand_graph(10, 30, ndata=rand(10),
234+
edata=rand(30),
235+
gdata=(u=rand(2), z=rand(1), q=1),
236+
graph_type=GRAPH_T)
237+
@test size(g.ndata.x) == (10,)
238+
@test size(g.edata.e) == (30,)
239+
@test size(g.gdata.u) == (2, 1)
240+
@test size(g.gdata.z) == (1,)
241+
@test g.gdata.q === 1
242+
243+
# Error for non-array ndata
244+
@test_throws AssertionError rand_graph(10, 30, ndata="ciao", graph_type=GRAPH_T)
245+
@test_throws AssertionError rand_graph(10, 30, ndata=1, graph_type=GRAPH_T)
232246
end
233247

234248
@testset "LearnBase and DataLoader compat" begin

0 commit comments

Comments
 (0)