Skip to content

Commit b0e5dd9

Browse files
use GNNlib in GNN.jl
1 parent 80c672a commit b0e5dd9

File tree

7 files changed

+49
-435
lines changed

7 files changed

+49
-435
lines changed
Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,37 @@
1+
module GNNlibCUDAExt
2+
3+
using CUDA
4+
using Random, Statistics, LinearAlgebra
5+
using GNNlib: GNNlib, propagate, copy_xj, e_mul_xj, w_mul_xj
6+
using GNNGraphs: GNNGraph, COO_T, SPARSE_T
17

28
###### PROPAGATE SPECIALIZATIONS ####################
39

410
## COPY_XJ
511

612
## avoid the fast path on gpu until we have better cuda support
7-
function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
8-
xi, xj::AnyCuMatrix, e)
13+
function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
14+
xi, xj::AnyCuMatrix, e)
915
propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
1016
end
1117

1218
## E_MUL_XJ
1319

1420
## avoid the fast path on gpu until we have better cuda support
15-
function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
16-
xi, xj::AnyCuMatrix, e::AbstractVector)
21+
function GNNlib.propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
22+
xi, xj::AnyCuMatrix, e::AbstractVector)
1723
propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e)
1824
end
1925

2026
## W_MUL_XJ
2127

2228
## avoid the fast path on gpu until we have better cuda support
23-
function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
24-
xi, xj::AnyCuMatrix, e::Nothing)
29+
function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
30+
xi, xj::AnyCuMatrix, e::Nothing)
2531
propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
2632
end
2733

28-
# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
34+
# function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
2935
# A = adjacency_matrix(g, weighted=false)
3036
# D = compute_degree(A)
3137
# return xj * A * D
@@ -35,3 +41,5 @@ end
3541
# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2)))
3642

3743
# Flux.Zygote.@nograd compute_degree
44+
45+
end #module

GNNlib/ext/GNNlibCUDAExt/GNNlibCUDAExt.jl

Lines changed: 0 additions & 11 deletions
This file was deleted.

Project.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
99
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1010
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1111
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
12+
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
13-
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1414
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
15+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1516
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1617
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1718
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -20,8 +21,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2021
[weakdeps]
2122
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2223

23-
[extensions]
24-
GraphNeuralNetworksCUDAExt = "CUDA"
24+
# [extensions]
25+
# GraphNeuralNetworksCUDAExt = "CUDA"
2526

2627
[compat]
2728
CUDA = "4, 5"
@@ -30,9 +31,10 @@ DataStructures = "0.18"
3031
Flux = "0.14"
3132
Functors = "0.4.1"
3233
GNNGraphs = "1.0"
34+
GNNlib = "0.2"
3335
LinearAlgebra = "1"
34-
MacroTools = "0.5"
3536
MLUtils = "0.4"
37+
MacroTools = "0.5"
3638
NNlib = "0.9"
3739
Random = "1"
3840
Reexport = "1"

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using ChainRulesCore
1111
using Reexport
1212
using DataStructures: nlargest
1313
using MLUtils: zeros_like
14+
using GNNlib: GNNlib
1415

1516
@reexport using GNNGraphs
1617
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,

src/deprecations.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11

2-
@deprecate AGNNConv(init_beta) AGNNConv(; init_beta)
2+
# V1.0 deprecations
3+
# TODO doe some reason this is not working
4+
# @deprecate (l::GCNConv)(g, x, edge_weight, norm_fn; conv_weight=nothing) l(g, x, edge_weight; norm_fn, conv_weight)

0 commit comments

Comments
 (0)