Skip to content

Commit f72fbbb

Browse files
improve degree and more test fixes (#268)
* fix gpu test + test rearrangment edge_weight in rand_graph drop julia 1.6 improve degree * fix the rules * fix * rrule for adjacency matrix * cleanup * cleanup * ops * all tests passing also on gpu
1 parent 7397577 commit f72fbbb

File tree

14 files changed

+240
-87
lines changed

14 files changed

+240
-87
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
fail-fast: false
1515
matrix:
1616
version:
17-
- '1.6' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'.
17+
- '1.7' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'.
1818
- '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia.
1919
- 'nightly'
2020
os:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ NNlibCUDA = "0.2"
4141
NearestNeighbors = "0.4"
4242
Reexport = "1"
4343
StatsBase = "0.33"
44-
julia = "1.6"
44+
julia = "1.7"
4545

4646
[extras]
4747
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/GNNGraphs/generate.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""
2-
rand_graph(n, m; bidirected=true, seed=-1, kws...)
2+
rand_graph(n, m; bidirected=true, seed=-1, edge_weight = nothing, kws...)
33
4-
Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes
5-
and `m` edges.
4+
Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes and `m` edges.
65
76
If `bidirected=true` the reverse edge of each edge will be present.
87
If `bidirected=false` instead, `m` unrelated edges are generated.
98
In any case, the output graph will contain no self-loops or multi-edges.
109
10+
A vector can be passed as `edge_weight`. Its length has to be equal to `m`
11+
in the directed case, and `m÷2` in the bidirected one.
12+
1113
Use a `seed > 0` for reproducibility.
1214
1315
Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
@@ -37,12 +39,12 @@ julia> edge_index(g)
3739
3840
```
3941
"""
40-
function rand_graph(n::Integer, m::Integer; bidirected = true, seed = -1, kws...)
42+
function rand_graph(n::Integer, m::Integer; bidirected = true, seed = -1, edge_weight = nothing, kws...)
4143
if bidirected
4244
@assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m."
4345
end
4446
m2 = bidirected ? m ÷ 2 : m
45-
return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed = !bidirected, seed); kws...)
47+
return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed = !bidirected, seed); edge_weight, kws...)
4648
end
4749

4850
"""
@@ -68,7 +70,7 @@ GNNHeteroGraph:
6870
num_edges: ((:user, :rate, :movie) => 30,)
6971
```
7072
"""
71-
rand_heteropraph
73+
function rand_heteropraph end
7274

7375
# for generic iterators of pairs
7476
rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...)

src/GNNGraphs/gnngraph.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,20 @@ GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
176176

177177
# GNNGraph(g::AbstractGraph; kws...) = GNNGraph(adjacency_matrix(g, dir=:out); kws...)
178178

179-
function GNNGraph(g::AbstractGraph; kws...)
179+
function GNNGraph(g::AbstractGraph; edge_weight = nothing, kws...)
180180
s = Graphs.src.(Graphs.edges(g))
181181
t = Graphs.dst.(Graphs.edges(g))
182+
w = edge_weight
182183
if !Graphs.is_directed(g)
183184
# add reverse edges since GNNGraph is directed
184185
s, t = [s; t], [t; s]
186+
if !isnothing(w)
187+
@assert length(w) == Graphs.ne(g) "edge_weight must have length equal to the number of undirected edges"
188+
w = [w; w]
189+
end
185190
end
186191
num_nodes::Int = Graphs.nv(g)
187-
GNNGraph((s, t); num_nodes = num_nodes, kws...)
192+
GNNGraph((s, t, w); num_nodes = num_nodes, kws...)
188193
end
189194

190195
function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata,

src/GNNGraphs/query.jl

Lines changed: 121 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ If `weighted=true`, the `A` will contain the edge weights if any, otherwise the
144144
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType = eltype(g); dir = :out,
145145
weighted = true)
146146
if g.graph[1] isa CuVector
147-
# TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
147+
# Revisit after
148+
# https://github.com/JuliaGPU/CUDA.jl/issues/1113
148149
A, n, m = to_dense(g.graph, T; num_nodes = g.num_nodes, weighted)
149150
else
150151
A, n, m = to_sparse(g.graph, T; num_nodes = g.num_nodes, weighted)
@@ -164,63 +165,101 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g
164165
return dir == :out ? A : A'
165166
end
166167

167-
function _get_edge_weight(g, edge_weight)
168-
if edge_weight === true || edge_weight === nothing
169-
ew = get_edge_weight(g)
170-
elseif edge_weight === false
171-
ew = nothing
172-
elseif edge_weight isa AbstractVector
173-
ew = edge_weight
168+
function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
169+
dir = :out, weighted = true) where {G <: GNNGraph{<:ADJMAT_T}}
170+
A = adjacency_matrix(g, T; dir, weighted)
171+
if !weighted
172+
function adjacency_matrix_pullback_noweight(Δ)
173+
return (NoTangent(), ZeroTangent(), NoTangent())
174+
end
175+
return A, adjacency_matrix_pullback_noweight
174176
else
175-
error("Invalid edge_weight argument.")
177+
function adjacency_matrix_pullback_weighted(Δ)
178+
dg = Tangent{G}(; graph = Δ .* binarize(A))
179+
return (NoTangent(), dg, NoTangent())
180+
end
181+
return A, adjacency_matrix_pullback_weighted
182+
end
183+
end
184+
185+
function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
186+
dir = :out, weighted = true) where {G <: GNNGraph{<:COO_T}}
187+
A = adjacency_matrix(g, T; dir, weighted)
188+
w = get_edge_weight(g)
189+
if !weighted || w === nothing
190+
function adjacency_matrix_pullback_noweight(Δ)
191+
return (NoTangent(), ZeroTangent(), NoTangent())
192+
end
193+
return A, adjacency_matrix_pullback_noweight
194+
else
195+
function adjacency_matrix_pullback_weighted(Δ)
196+
s, t = edge_index(g)
197+
dg = Tangent{G}(; graph = (NoTangent(), NoTangent(), NNlib.gather(Δ, s, t)))
198+
return (NoTangent(), dg, NoTangent())
199+
end
200+
return A, adjacency_matrix_pullback_weighted
201+
end
202+
end
203+
204+
function _get_edge_weight(g, edge_weight::Bool)
205+
if edge_weight === true
206+
return get_edge_weight(g)
207+
elseif edge_weight === false
208+
return nothing
176209
end
177-
return ew
178210
end
179211

212+
_get_edge_weight(g, edge_weight::AbstractVector) = edge_weight
213+
180214
"""
181215
degree(g::GNNGraph, T=nothing; dir=:out, edge_weight=true)
182216
183217
Return a vector containing the degrees of the nodes in `g`.
184218
219+
The gradient is propagated through this function only if `edge_weight` is `true`
220+
or a vector.
221+
185222
# Arguments
223+
186224
- `g`: A graph.
187225
- `T`: Element type of the returned vector. If `nothing`, is
188226
chosen based on the graph type and will be an integer
189-
if `edge_weight=false`.
227+
if `edge_weight=false`. Default `nothing`.
190228
- `dir`: For `dir=:out` the degree of a node is counted based on the outgoing edges.
191229
For `dir=:in`, the ingoing edges are used. If `dir=:both` we have the sum of the two.
192230
- `edge_weight`: If `true` and the graph contains weighted edges, the degree will
193231
be weighted. Set to `false` instead to just count the number of
194-
outgoing/ingoing edges.
195-
In alternative, you can also pass a vector of weights to be used
232+
outgoing/ingoing edges.
233+
Finally, you can also pass a vector of weights to be used
196234
instead of the graph's own weights.
235+
Default `true`.
236+
197237
"""
198238
function Graphs.degree(g::GNNGraph{<:COO_T}, T::TT = nothing; dir = :out,
199239
edge_weight = true) where {
200240
TT <: Union{Nothing, Type{<:Number}}}
201241
s, t = edge_index(g)
202242

203-
edge_weight = _get_edge_weight(g, edge_weight)
204-
edge_weight = edge_weight === nothing ? ones_like(s) : edge_weight
205-
206-
T = isnothing(T) ? eltype(edge_weight) : T
207-
degs = fill!(similar(s, T, g.num_nodes), 0)
208-
209-
if dir [:out, :both]
210-
degs = degs .+ NNlib.scatter(+, edge_weight, s, dstsize = (g.num_nodes,))
211-
end
212-
if dir [:in, :both]
213-
degs = degs .+ NNlib.scatter(+, edge_weight, t, dstsize = (g.num_nodes,))
214-
end
215-
return degs
243+
ew = _get_edge_weight(g, edge_weight)
244+
245+
T = if isnothing(T)
246+
if !isnothing(ew)
247+
eltype(ew)
248+
else
249+
eltype(s)
250+
end
251+
else
252+
T
253+
end
254+
return _degree((s, t), T, dir, ew, g.num_nodes)
216255
end
217256

218257
# TODO:: Make efficient
219258
Graphs.degree(g::GNNGraph, i::Union{Int, AbstractVector}; dir = :out) = degree(g; dir)[i]
220259

221260
function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT = nothing; dir = :out,
222-
edge_weight = true) where {TT}
223-
TT <: Union{Nothing, Type{<:Number}}
261+
edge_weight = true) where {TT<:Union{Nothing, Type{<:Number}}}
262+
224263
# edge_weight=true or edge_weight=nothing act the same here
225264
@assert !(edge_weight isa AbstractArray) "passing the edge weights is not support by adjacency matrix representations"
226265
@assert dir (:in, :out, :both)
@@ -234,6 +273,26 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT = nothing; dir = :out,
234273
end
235274
end
236275
A = adjacency_matrix(g)
276+
return _degree(A, T, dir, edge_weight, g.num_nodes)
277+
end
278+
279+
function _degree((s, t)::Tuple, T::Type, dir::Symbol, edge_weight::Nothing, num_nodes::Int)
280+
_degree((s, t), T, dir, ones_like(s, T), num_nodes)
281+
end
282+
283+
function _degree((s, t)::Tuple, T::Type, dir::Symbol, edge_weight::AbstractVector, num_nodes::Int)
284+
degs = fill!(similar(s, T, num_nodes), 0)
285+
286+
if dir [:out, :both]
287+
degs = degs .+ NNlib.scatter(+, edge_weight, s, dstsize = (num_nodes,))
288+
end
289+
if dir [:in, :both]
290+
degs = degs .+ NNlib.scatter(+, edge_weight, t, dstsize = (num_nodes,))
291+
end
292+
return degs
293+
end
294+
295+
function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num_nodes::Int)
237296
if edge_weight === false
238297
A = binarize(A)
239298
end
@@ -243,6 +302,40 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT = nothing; dir = :out,
243302
vec(sum(A, dims = 1)) .+ vec(sum(A, dims = 2))
244303
end
245304

305+
function ChainRulesCore.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes)
306+
degs = _degree(graph, T, dir, edge_weight, num_nodes)
307+
function _degree_pullback(Δ)
308+
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
309+
end
310+
return degs, _degree_pullback
311+
end
312+
313+
function ChainRulesCore.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes)
314+
degs = _degree(A, T, dir, edge_weight, num_nodes)
315+
if edge_weight === false
316+
function _degree_pullback_noweights(Δ)
317+
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
318+
end
319+
return degs, _degree_pullback_noweights
320+
else
321+
function _degree_pullback_weights(Δ)
322+
# We propagate the gradient only to the non-zero elements
323+
# of the adjacency matrix.
324+
bA = binarize(A)
325+
if dir == :in
326+
dA = bA .* Δ'
327+
elseif dir == :out
328+
dA = Δ .* bA
329+
else # dir == :both
330+
dA = Δ .* bA + Δ' .* bA
331+
end
332+
return (NoTangent(), dA, NoTangent(), NoTangent(), NoTangent(), NoTangent())
333+
end
334+
return degs, _degree_pullback_weights
335+
end
336+
end
337+
338+
246339
"""
247340
has_isolated_nodes(g::GNNGraph; dir=:out)
248341

src/layers/conv.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,17 @@ and optionally an edge weight vector.
2626
- `init`: Weights' initializer. Default `glorot_uniform`.
2727
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
2828
- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
29-
If `add_self_loops=true` the new weights will be set to 1. Default `false`.
29+
If `add_self_loops=true` the new weights will be set to 1.
30+
This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
31+
Default `false`.
32+
33+
# Forward
34+
35+
(::GCNConv)(g::GNNGraph, x::AbstractMatrix, edge_weight = nothing) -> AbstractMatrix
36+
37+
Takes as input a graph `g`,ca node feature matrix `x` of size `[in, num_nodes]`,
38+
and optionally an edge weight vector. Returns a node feature matrix of size
39+
`[out, num_nodes]`.
3040
3141
# Examples
3242
@@ -107,7 +117,11 @@ function (l::GCNConv)(g::GNNGraph,
107117
# multiply before convolution if it is more convenient, otherwise multiply after
108118
x = l.weight * x
109119
end
110-
d = degree(g, T; dir = :in, edge_weight)
120+
if edge_weight !== nothing
121+
d = degree(g, T; dir = :in, edge_weight)
122+
else
123+
d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
124+
end
111125
c = 1 ./ sqrt.(d)
112126
x = x .* c'
113127
if edge_weight !== nothing
@@ -1288,7 +1302,11 @@ function (l::SGConv)(g::GNNGraph, x::AbstractMatrix{T},
12881302
if Dout < Din
12891303
x = l.weight * x
12901304
end
1291-
d = degree(g, T; dir = :in, edge_weight)
1305+
if edge_weight !== nothing
1306+
d = degree(g, T; dir = :in, edge_weight)
1307+
else
1308+
d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight)
1309+
end
12921310
c = 1 ./ sqrt.(d)
12931311
for iter in 1:(l.k)
12941312
x = x .* c'

test/GNNGraphs/convert.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@ if TEST_GPU
1010
@test Array(y) [2, 2, 2, 2, 2, 2]
1111

1212
s, t = get_st(A)
13-
@test s isa CuVector
14-
@test t isa CuVector
15-
@test_broken s isa CuVector{Int32}
16-
@test_broken t isa CuVector{Int32}
13+
@test s isa CuVector{<:Integer}
14+
@test t isa CuVector{<:Integer}
1715
@test Array(s) == [2, 3, 1, 3, 1, 2]
1816
@test Array(t) == [1, 1, 2, 2, 3, 3]
1917

test/GNNGraphs/generate.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@
2222

2323
g2 = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T)
2424
@test edge_index(g2) == edge_index(g)
25+
26+
ew = rand(m2)
27+
g = rand_graph(n, m, bidirected = true, seed = 17, graph_type = GRAPH_T, edge_weight = ew)
28+
@test get_edge_weight(g) == [ew; ew] broken=(GRAPH_T != :coo)
29+
30+
ew = rand(m)
31+
g = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T, edge_weight = ew)
32+
@test get_edge_weight(g) == ew broken=(GRAPH_T != :coo)
2533
end
2634

2735
@testset "knn_graph" begin

test/GNNGraphs/gnnheterograph.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,3 @@ end
9090
# @test sprint(show, MIME("text/plain"), hg3; context=:compact => false) =="GNNHeteroGraph:\n num_nodes: (:A => 10, :B => 20)\n num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)\n ndata:\n\t:A => (x = 2×10 Matrix{Float64}, y = 3×10 Matrix{Float64})\n\t:B => x = 10×20 Matrix{Float64}"
9191
# @test sprint(show, MIME("text/plain"), hg2; context=:compact => false) != sprint(show, MIME("text/plain"), hg3; context=:compact => false)
9292
# end
93-

0 commit comments

Comments
 (0)