Skip to content

Commit dd4a54c

Browse files
differentiable adjacency_matrix and degree (#123)
* differentiable adjacency_matrix for coo * fixes for degree * differentiable adjacency_matrix for dense * fix some cuda problems * add binarize
1 parent 309b88e commit dd4a54c

File tree

12 files changed

+203
-91
lines changed

12 files changed

+203
-91
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
*.jl.mem
44
Manifest.toml
55
/docs/build/
6+
.vscode

src/GNNGraphs/convert.jl

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,16 @@ function to_coo(A::SPARSE_T; dir=:out, num_nodes=nothing, weighted=true)
2929
return (s, t, v), num_nodes, num_edges
3030
end
3131

32-
function to_coo(A::ADJMAT_T; dir=:out, num_nodes=nothing, weighted=true)
32+
function _findnz_idx(A)
3333
nz = findall(!=(0), A) # vec of cartesian indexes
3434
s, t = ntuple(i -> map(t->t[i], nz), 2)
35+
return s, t, nz
36+
end
37+
38+
@non_differentiable _findnz_idx(A)
39+
40+
function to_coo(A::ADJMAT_T; dir=:out, num_nodes=nothing, weighted=true)
41+
s, t, nz = _findnz_idx(A)
3542
v = A[nz]
3643
if dir == :in
3744
s, t = t, s
@@ -115,16 +122,24 @@ function to_dense(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted=t
115122
# The output will always be an adjmat in :out format (e.g. A[i,j] denotes an edge from i to j)
116123
s, t, val = coo
117124
n::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
118-
val = isnothing(val) ? eltype(s)(1) : val
119-
T = T === nothing ? eltype(val) : T
120-
if !weighted
121-
val = T(1)
125+
if T === nothing
126+
T = isnothing(val) ? eltype(s) : eltype(val)
122127
end
123-
A = fill!(similar(s, T, (n, n)), 0)
124-
v = vec(A) # vec view of A
128+
if val === nothing || !weighted
129+
val = ones_like(s, T)
130+
end
131+
if eltype(val) != T
132+
val = T.(val)
133+
end
134+
125135
idxs = s .+ n .* (t .- 1)
136+
137+
## using scatter instead of indexing since there could be multiple edges
138+
# A = fill!(similar(s, T, (n, n)), 0)
139+
# v = vec(A) # vec view of A
126140
# A[idxs] .= val # exploiting linear indexing
127-
NNlib.scatter!(+, v, val, idxs) # using scatter instead of indexing since there could be multiple edges
141+
v = NNlib.scatter(+, val, idxs, dstsize=n^2)
142+
A = reshape(v, (n, n))
128143
return A, n, length(s)
129144
end
130145

@@ -172,7 +187,3 @@ function to_sparse(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted=
172187
end
173188
return A, num_nodes, num_edges
174189
end
175-
176-
@non_differentiable to_coo(x...)
177-
@non_differentiable to_dense(x...)
178-
@non_differentiable to_sparse(x...)

src/GNNGraphs/gatherscatter.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,61 @@ _scatter(aggr, m::NamedTuple, t; dstsize=nothing) = map(m -> _scatter(aggr, m, t
77
_scatter(aggr, m::Tuple, t; dstsize=nothing) = map(m -> _scatter(aggr, m, t; dstsize), m)
88
_scatter(aggr, m::AbstractArray, t; dstsize=nothing) = NNlib.scatter(aggr, m, t; dstsize)
99
_scatter(aggr, m::Nothing, t; dstsize=nothing) = nothing
10+
11+
## TO MOVE TO NNlib ######################################################
12+
13+
14+
### Considers the src a zero dimensional object.
15+
### Useful for implementing `StatsBase.counts`, `degree`, etc...
16+
### function NNlib.scatter!(op, dst::AbstractArray, src::Number, idx::AbstractArray)
17+
### for k in CartesianIndices(idx)
18+
### # dst_v = NNlib._view(dst, idx[k])
19+
### # dst_v .= (op).(dst_v, src)
20+
### dst[idx[k]] .= (op).(dst[idx[k]], src)
21+
### end
22+
### dst
23+
### end
24+
25+
# 10 times faster than the generic version above.
26+
# All the speedup comes from not broadcasting `op`; the reason is not yet understood.
27+
# function NNlib.scatter!(op, dst::AbstractVector, src::Number, idx::AbstractVector{<:Integer})
28+
# for i in idx
29+
# dst[i] = op(dst[i], src)
30+
# end
31+
# end
32+
33+
## NNlib._view(X, k) = view(X, k...)
34+
## NNlib._view(X, k::Union{Integer, CartesianIndex}) = view(X, k)
35+
#
36+
## Considers src as a zero dimensional object to be scattered
37+
## function NNlib.scatter(op,
38+
## src::Tsrc,
39+
## idx::AbstractArray{Tidx,Nidx};
40+
## init = nothing, dstsize = nothing) where {Tsrc<:Number,Tidx,Nidx}
41+
## dstsz = isnothing(dstsize) ? maximum_dims(idx) : dstsize
42+
## dst = similar(src, Tsrc, dstsz)
43+
## xinit = isnothing(init) ? scatter_empty(op, Tsrc) : init
44+
## fill!(dst, xinit)
45+
## scatter!(op, dst, src, idx)
46+
## end
47+
48+
# function scatter_scalar_kernel!(op, dst, src, idx)
49+
# index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
50+
51+
# @inbounds if index <= length(idx)
52+
# CUDA.@atomic dst[idx[index]...] = op(dst[idx[index]...], src)
53+
# end
54+
# return nothing
55+
# end
56+
57+
# function NNlib.scatter!(op, dst::AnyCuArray, src::Number, idx::AnyCuArray)
58+
# max_idx = length(idx)
59+
# args = op, dst, src, idx
60+
61+
# kernel = @cuda launch=false scatter_scalar_kernel!(args...)
62+
# config = launch_configuration(kernel.fun; max_threads=256)
63+
# threads = min(max_idx, config.threads)
64+
# blocks = cld(max_idx, threads)
65+
# kernel(args...; threads=threads, blocks=blocks)
66+
# return dst
67+
# end

src/GNNGraphs/query.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,6 @@ end
127127
# return [fneighs(g, i) for i in nodes]
128128
# end
129129

130-
131-
132130
adjacency_list(g::GNNGraph; dir=:out) = adjacency_list(g, 1:g.num_nodes; dir)
133131

134132

@@ -159,7 +157,7 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=eltype(g);
159157
@assert dir ∈ [:in, :out]
160158
A = g.graph
161159
if !weighted
162-
A = map(>(0), A)
160+
A = binarize(A)
163161
end
164162
A = T != eltype(A) ? T.(A) : A
165163
return dir == :out ? A : A'
@@ -201,15 +199,16 @@ function Graphs.degree(g::GNNGraph{<:COO_T}, T::TT=nothing; dir=:out, edge_weigh
201199
s, t = edge_index(g)
202200

203201
edge_weight = _get_edge_weight(g, edge_weight)
204-
edge_weight = edge_weight === nothing ? eltype(s)(1) : edge_weight
202+
edge_weight = edge_weight === nothing ? ones_like(s) : edge_weight
205203

206204
T = isnothing(T) ? eltype(edge_weight) : T
207205
degs = fill!(similar(s, T, g.num_nodes), 0)
206+
208207
if dir ∈ [:out, :both]
209-
NNlib.scatter!(+, degs, edge_weight, s)
208+
degs = degs .+ NNlib.scatter(+, edge_weight, s, dstsize=(g.num_nodes,))
210209
end
211210
if dir ∈ [:in, :both]
212-
NNlib.scatter!(+, degs, edge_weight, t)
211+
degs = degs .+ NNlib.scatter(+, edge_weight, t, dstsize=(g.num_nodes,))
213212
end
214213
return degs
215214
end
@@ -233,7 +232,7 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT=nothing; dir=:out, edge_we
233232
end
234233
A = adjacency_matrix(g)
235234
if edge_weight === false
236-
A = map(>(0), A)
235+
A = binarize(A)
237236
end
238237
A = eltype(A) != T ? T.(A) : A
239238
return dir == :out ? vec(sum(A, dims=2)) :
@@ -394,14 +393,12 @@ function has_multi_edges(g::GNNGraph)
394393
length(union(idxs)) < length(idxs)
395394
end
396395

397-
396+
@non_differentiable edge_index(x...)
398397
@non_differentiable adjacency_list(x...)
399-
@non_differentiable adjacency_matrix(x...)
400-
@non_differentiable degree(x...)
401398
@non_differentiable graph_indicator(x...)
402399
@non_differentiable has_multi_edges(x...)
403400
@non_differentiable Graphs.has_self_loops(x...)
404401
@non_differentiable is_bidirected(x...)
405-
@non_differentiable normalized_adjacency(x...)
406-
@non_differentiable normalized_laplacian(x...)
407-
@non_differentiable scaled_laplacian(x...)
402+
@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
403+
@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
404+
@non_differentiable scaled_laplacian(x...) # TODO remove this in the future

src/GNNGraphs/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ function edge_decoding(idx, n; directed=true)
149149
return s, t
150150
end
151151

152+
binarize(x) = map(>(0), x)
153+
154+
@non_differentiable binarize(x...)
152155
@non_differentiable edge_encoding(x...)
153156
@non_differentiable edge_decoding(x...)
154157

src/layers/conv.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,14 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}, edge_weight::EW=nothing
7979

8080
@assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"
8181

82+
if edge_weight !== nothing
83+
@assert length(edge_weight) == g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
84+
end
85+
8286
if l.add_self_loops
8387
g = add_self_loops(g)
8488
if edge_weight !== nothing
89+
# Pad weights with ones
8590
# TODO for ADJMAT_T the new edges are not generally at the end
8691
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
8792
@assert length(edge_weight) == g.num_edges

src/utils.jl

Lines changed: 13 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,5 @@
11
ofeltype(x, y) = convert(float(eltype(x)), y)
22

3-
# Considers the src a zero dimensional object.
4-
# Useful for implementing `StatsBase.counts`, `degree`, etc...
5-
# function NNlib.scatter!(op, dst::AbstractArray, src::Number, idx::AbstractArray)
6-
# for k in CartesianIndices(idx)
7-
# # dst_v = NNlib._view(dst, idx[k])
8-
# # dst_v .= (op).(dst_v, src)
9-
# dst[idx[k]] .= (op).(dst[idx[k]], src)
10-
# end
11-
# dst
12-
# end
13-
14-
# 10 time faster than the generic version above.
15-
# All the speedup comes from not broadcasting `op`, i dunno why.
16-
function NNlib.scatter!(op, dst::AbstractVector, src::Number, idx::AbstractVector{<:Integer})
17-
for i in idx
18-
dst[i] = op(dst[i], src)
19-
end
20-
end
21-
22-
# NNlib._view(X, k) = view(X, k...)
23-
# NNlib._view(X, k::Union{Integer, CartesianIndex}) = view(X, k)
24-
25-
# Considers src as a zero dimensional object to be scattered
26-
# function NNlib.scatter(op,
27-
# src::Tsrc,
28-
# idx::AbstractArray{Tidx,Nidx};
29-
# init = nothing, dstsize = nothing) where {Tsrc<:Number,Tidx,Nidx}
30-
31-
# dstsz = isnothing(dstsize) ? maximum_dims(idx) : dstsize
32-
# dst = similar(src, Tsrc, dstsz)
33-
# xinit = isnothing(init) ? scatter_empty(op, Tsrc) : init
34-
# fill!(dst, xinit)
35-
# scatter!(op, dst, src, idx)
36-
# end
37-
38-
39-
function scatter_scalar_kernel!(op, dst, src, idx)
40-
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
41-
42-
@inbounds if index <= length(idx)
43-
CUDA.@atomic dst[idx[index]...] = op(dst[idx[index]...], src)
44-
end
45-
return nothing
46-
end
47-
48-
function NNlib.scatter!(op, dst::AnyCuArray, src::Number, idx::AnyCuArray)
49-
max_idx = length(idx)
50-
args = op, dst, src, idx
51-
52-
kernel = @cuda launch=false scatter_scalar_kernel!(args...)
53-
config = launch_configuration(kernel.fun; max_threads=256)
54-
threads = min(max_idx, config.threads)
55-
blocks = cld(max_idx, threads)
56-
kernel(args...; threads=threads, blocks=blocks)
57-
return dst
58-
end
59-
603
"""
614
reduce_nodes(aggr, g, x)
625
@@ -157,3 +100,16 @@ function broadcast_edges(g::GNNGraph, x)
157100
return gather(x, gi)
158101
end
159102

103+
# More generic version of
104+
# https://github.com/JuliaDiff/ChainRules.jl/pull/586
105+
# This applies to all arrays
106+
# Without this, the gradient of T.(A) for a dense GPU matrix A errors.
107+
function ChainRulesCore.rrule(::typeof(Broadcast.broadcasted), T::Type{<:Number}, x::AbstractArray)
108+
proj = ProjectTo(x)
109+
110+
function broadcasted_cast(Δ)
111+
return NoTangent(), NoTangent(), proj(Δ)
112+
end
113+
114+
return T.(x), broadcasted_cast
115+
end

test/GNNGraphs/convert.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
if TEST_GPU
2+
@testset "to_coo(dense) on gpu" begin
3+
get_st(A) = GNNGraphs.to_coo(A)[1][1:2]
4+
get_val(A) = GNNGraphs.to_coo(A)[1][3]
5+
6+
A = cu([0 2 2; 2. 0 2; 2 2 0])
7+
8+
y = get_val(A)
9+
@test y isa CuVector{Float32}
10+
@test Array(y) ≈ [2, 2, 2, 2, 2, 2]
11+
12+
s, t = get_st(A)
13+
@test s isa CuVector
14+
@test t isa CuVector
15+
@test_broken s isa CuVector{Int32}
16+
@test_broken t isa CuVector{Int32}
17+
@test Array(s) == [2, 3, 1, 3, 1, 2]
18+
@test Array(t) == [1, 1, 2, 2, 3, 3]
19+
20+
@test gradient(A -> sum(get_val(A)), A)[1] isa CuMatrix{Float32}
21+
end
22+
end

test/GNNGraphs/query.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,39 @@
7171
end
7272
@test eltype(d) <: Integer
7373
if GRAPH_T == :coo
74+
# TODO use the @test option broken = (GRAPH_T != :coo) on julia >= 1.7
7475
@test degree(g, edge_weight=2*eweight) == [4.4, 2.4, 2.0, 0.0]
76+
else
77+
@test_broken degree(g, edge_weight=2*eweight) == [4.4, 2.4, 2.0, 0.0]
7578
end
76-
79+
7780
if TEST_GPU
7881
g_gpu = g |> gpu
7982
d = degree(g)
8083
d_gpu = degree(g_gpu)
8184
@test d_gpu isa CuVector{Float32}
8285
@test Array(d_gpu) ≈ d
8386
end
87+
@testset "gradient" begin
88+
gw = gradient(eweight) do w
89+
g = GNNGraph((s, t, w), graph_type=GRAPH_T)
90+
sum(degree(g, edge_weight=false))
91+
end[1]
92+
93+
@test gw === nothing
94+
95+
gw = gradient(eweight) do w
96+
g = GNNGraph((s, t, w), graph_type=GRAPH_T)
97+
sum(degree(g, edge_weight=true))
98+
end[1]
99+
100+
if GRAPH_T == :sparse
101+
@test_broken gw isa Vector{Float64}
102+
@test gw isa AbstractVector{Float64}
103+
else
104+
@test gw isa Vector{Float64}
105+
end
106+
end
84107
end
85108
end
86109

@@ -105,5 +128,25 @@
105128
Abin = adjacency_matrix(g, Float32, weighted=false)
106129
@test Abin ≈ abin
107130
@test eltype(Abin) == Float32
131+
132+
@testset "gradient" begin
133+
s = [1,2,3]
134+
t = [2,3,1]
135+
w = [0.1,0.1,0.2]
136+
gw = gradient(w) do w
137+
g = GNNGraph(s, t, w, graph_type=GRAPH_T)
138+
A = adjacency_matrix(g, weighted=false)
139+
sum(A)
140+
end[1]
141+
@test gw === nothing
142+
143+
gw = gradient(w) do w
144+
g = GNNGraph(s, t, w, graph_type=GRAPH_T)
145+
A = adjacency_matrix(g, weighted=true)
146+
sum(A)
147+
end[1]
148+
149+
@test gw == [1,1,1]
150+
end
108151
end
109152
end

0 commit comments

Comments
 (0)