Skip to content

Commit efb146a

Browse files
keep vectors
1 parent 670556c commit efb146a

File tree

3 files changed

+31
-21
lines changed

3 files changed

+31
-21
lines changed

src/GNNGraphs/gnngraph.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,6 @@ function Base.show(io::IO, g::GNNGraph)
206206
if !isempty(g.ndata)
207207
print(io, "\n ndata:")
208208
for k in keys(g.ndata)
209-
# print(io, "\n $k => $(size(g.ndata[k]))")
210209
print(io, "\n $k => $(summary(g.ndata[k]))")
211210
end
212211
end

src/GNNGraphs/utils.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,17 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
6666
# If last array dimension is not 1, add a new dimension.
6767
# This is mostly useful to reshape global feature vectors
6868
# of size D to Dx1 matrices.
69-
function unsqz(v)
70-
if v isa AbstractArray && size(v)[end] != 1
71-
v = reshape(v, size(v)..., 1)
72-
end
73-
v
74-
end
69+
unsqz_last(v::AbstractArray) = size(v)[end] != 1 ? reshape(v, size(v)..., 1) : v
70+
unsqz_last(v) = v
7571

76-
data = NamedTuple{keys(data)}(unsqz.(values(data)))
72+
data = map(unsqz_last, data)
7773
end
78-
74+
75+
## Turn vectors in 1 x n matrices.
76+
# unsqz_first(v::AbstractVector) = reshape(v, 1, length(v))
77+
# unsqz_first(v) = v
78+
# data = map(unsqz_first, data)
79+
7980
sz = map(x -> x isa AbstractArray ? size(x)[end] : 0, data)
8081

8182
if duplicate_if_needed
@@ -88,10 +89,11 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
8889
end
8990
v
9091
end
91-
data = NamedTuple{keys(data)}(duplicate.(values(data)))
92-
else
93-
@assert all(x -> x == 0 || x == n, sz) "Wrong size in last dimension for feature array."
92+
data = map(duplicate, data)
9493
end
94+
95+
@assert all(x -> x == 0 || x == n, sz) "Wrong size in last dimension for feature array."
96+
9597
return data
9698
end
9799

test/GNNGraphs/gnngraph.jl

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -225,15 +225,24 @@
225225
g = GNNGraph(erdos_renyi(10, 30), edata=e, graph_type=GRAPH_T)
226226
@test g.edata.e == [e; e]
227227

228-
229-
# Attach non array data
230-
g = GNNGraph(erdos_renyi(10, 30), edata="ciao", graph_type=GRAPH_T)
231-
@test g.edata.e == "ciao"
232-
233-
234-
# Wrong need number of features
235-
g = GNNGraph(erdos_renyi(10, 30), edata="ciao", graph_type=GRAPH_T)
236-
228+
# non-array global
229+
g = rand_graph(10, 30, gdata="ciao", graph_type=GRAPH_T)
230+
@test g.gdata.u == "ciao"
231+
232+
# vectors stays vectors
233+
g = rand_graph(10, 30, ndata=rand(10),
234+
edata=rand(30),
235+
gdata=(u=rand(2), z=rand(1), q=1),
236+
graph_type=GRAPH_T)
237+
@test size(g.ndata.x) == (10,)
238+
@test size(g.edata.e) == (30,)
239+
@test size(g.gdata.u) == (2, 1)
240+
@test size(g.gdata.z) == (1,)
241+
@test g.gdata.q === 1
242+
243+
# Error for non-array ndata
244+
@test_throws AssertionError rand_graph(10, 30, ndata="ciao", graph_type=GRAPH_T)
245+
@test_throws AssertionError rand_graph(10, 30, ndata=1, graph_type=GRAPH_T)
237246
end
238247

239248
@testset "LearnBase and DataLoader compat" begin

0 commit comments

Comments
 (0)