
Commit 8979643

better printing and check feature types
Parent: a82a320

5 files changed: 63 additions, 6 deletions


prova.jl

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+using GraphNeuralNetworks, Test, Flux
+using BenchmarkTools
+using ProfileView
+
+g = rand_graph(10, 30, ndata=rand(Float32, 2, 10))
+l = GATConv(2 => 2)
+y = l(g, g.ndata.x)
+@assert eltype(y) == Float32
+
+dx = gradient(x -> sum(sin.(l(g, x))), g.ndata.x)[1]
+@assert eltype(dx) == Float32
+
+struct B
+    slope
+end
+
+(a::B)(x) = leakyrelu(x, a.slope)
+
+a = B(0.3f0)
+grad = gradient(a -> a(0.1f0), a)[1]
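
The Float32 assertions in this scratch script guard against silent promotion to Float64, for example from a Float64 activation slope. A tiny illustration of the promotion rule (not part of the commit):

typeof(0.3f0 * 0.1f0)  # Float32
typeof(0.3  * 0.1f0)   # Float64: a Float64 constant promotes the whole result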

prova2.jl

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+using Flux
+using GraphNeuralNetworks
+using BenchmarkTools
+using ProfileView
+
+f(x) = 1
+
+function test(g)
+    loader = Flux.DataLoader(g, batchsize=100, shuffle=true)
+    s = 0
+    for d in loader
+        s += f(d)
+    end
+    return s
+end
+
+n = 5000
+s = 10
+data = [rand_graph(s, s, ndata = rand(1, s)) for i in 1:n]
+x1 = Flux.batch(data)
+x2 = Flux.batch([rand(s + s + s + s) for i in 1:n]) # source + target + data + extra
+
+# @profview test(x1)
+@btime test($x1);   # 1.295 s    (2502 allocations: 6.17 MiB)
+@btime test($x2);   # 357.595 μs (152 allocations: 1.61 MiB)
+@btime test($data); # 65.288 ms  (227002 allocations: 27.00 MiB)  # this PR
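
For orientation: test(x1) iterates mini-batches sliced out of one big batched GNNGraph, test(x2) iterates column slices of a plain matrix, and test(data) (marked "# this PR") iterates a vector of small graphs; the comments record machine-specific timings. A short sketch of the first case, with behavior assumed from the package's DataLoader compatibility:

loader = Flux.DataLoader(x1, batchsize=100, shuffle=true)
x1.num_graphs   # 5000
first(loader)   # assumed: a GNNGraph holding 100 of the small graphs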

src/GNNGraphs/gnngraph.jl

Lines changed: 4 additions & 3 deletions
@@ -206,19 +206,20 @@ function Base.show(io::IO, g::GNNGraph)
     if !isempty(g.ndata)
         print(io, "\n ndata:")
         for k in keys(g.ndata)
-            print(io, "\n $k => $(size(g.ndata[k]))")
+            # print(io, "\n $k => $(size(g.ndata[k]))")
+            print(io, "\n $k => $(summary(g.ndata[k]))")
         end
     end
     if !isempty(g.edata)
         print(io, "\n edata:")
         for k in keys(g.edata)
-            print(io, "\n $k => $(size(g.edata[k]))")
+            print(io, "\n $k => $(summary(g.edata[k]))")
         end
     end
     if !isempty(g.gdata)
         print(io, "\n gdata:")
         for k in keys(g.gdata)
-            print(io, "\n $k => $(size(g.gdata[k]))")
+            print(io, "\n $k => $(summary(g.gdata[k]))")
         end
     end
 end
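
For reference, a short illustration of what the switch from size to summary changes in the printed feature entries (exact summary string assumes Julia 1.6 or later):

x = rand(Float32, 2, 10)
size(x)     # (2, 10)                  -> old display: x => (2, 10)
summary(x)  # "2×10 Matrix{Float32}"   -> new display: x => 2×10 Matrix{Float32}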

src/GNNGraphs/utils.jl

Lines changed: 8 additions & 3 deletions
@@ -57,10 +57,14 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
     # This had to workaround two Zygote bugs with NamedTuples
     # https://github.com/FluxML/Zygote.jl/issues/1071
     # https://github.com/FluxML/Zygote.jl/issues/1072
+
+    if n != 1
+        @assert all(x -> x isa AbstractArray, data) "Non-array features provided."
+    end

     if n == 1
         # If last array dimension is not 1, add a new dimension.
-        # This is mostly usefule to reshape globale feature vectors
+        # This is mostly useful to reshape global feature vectors
         # of size D to Dx1 matrices.
         function unsqz(v)
             if v isa AbstractArray && size(v)[end] != 1
@@ -73,19 +77,20 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
     end

     sz = map(x -> x isa AbstractArray ? size(x)[end] : 0, data)
+
     if duplicate_if_needed
         # Used to copy edge features on reverse edges
         @assert all(s -> s == 0 || s == n || s == n÷2, sz) "Wrong size in last dimension for feature array."

         function duplicate(v)
-            if v isa AbstractArray && size(v)[end] == n÷2
+            if size(v)[end] == n÷2
                 v = cat(v, v, dims=ndims(v))
             end
             v
         end
         data = NamedTuple{keys(data)}(duplicate.(values(data)))
     else
-        @assert all(s -> s == 0 || s == n, sz) "Wrong size in last dimension for feature array."
+        @assert all(x -> x == 0 || x == n, sz) "Wrong size in last dimension for feature array."
     end
     return data
 end
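
A minimal, self-contained sketch of the new check (variable names here are illustrative, not taken from the package): node or edge features (n != 1) must all be AbstractArrays, while graph-level data (n == 1) is left alone.

data = (x = rand(Float32, 2, 10), y = "not an array")
n = 10   # number of nodes the features must cover
if n != 1
    @assert all(v -> v isa AbstractArray, data) "Non-array features provided."  # fails here because of :y
end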

test/GNNGraphs/gnngraph.jl

Lines changed: 5 additions & 0 deletions
@@ -229,6 +229,11 @@
         # Attach non array data
         g = GNNGraph(erdos_renyi(10, 30), edata="ciao", graph_type=GRAPH_T)
         @test g.edata.e == "ciao"
+
+
+        # Wrong number of features
+        g = GNNGraph(erdos_renyi(10, 30), edata="ciao", graph_type=GRAPH_T)
+
     end

     @testset "LearnBase and DataLoader compat" begin
