Commit 92d3163

Remove CUDA dependence in favor of extension (#318)
* cuda extension
* fix
1 parent: f59ce44 · commit: 92d3163

14 files changed, +84 -94 lines
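
What the commit does: CUDA (and cuDNN) move out of the package's hard dependencies, and all CUDA-specific code moves into a package extension, GraphNeuralNetworksCUDAExt, which Julia (>= 1.9) loads automatically only when CUDA is also loaded. From the user's side this behaves roughly as follows (a minimal sketch; `Base.get_extension` is the standard way to check that an extension activated):

    using GraphNeuralNetworks    # no longer pulls in CUDA

    using CUDA                   # loading the weak dependency activates the extension

    # Optional sanity check that the extension module is now loaded:
    ext = Base.get_extension(GraphNeuralNetworks, :GraphNeuralNetworksCUDAExt)
    @assert ext isa Module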

Project.toml

Lines changed: 11 additions & 3 deletions
@@ -3,9 +3,14 @@ uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 authors = ["Carlo Lucibello and contributors"]
 version = "0.6.8"
 
+[weakdeps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+
+[extensions]
+GraphNeuralNetworksCUDAExt = "CUDA"
+
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -22,7 +27,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [compat]
 Adapt = "3"
@@ -46,12 +50,16 @@ julia = "1.9"
 [extras]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [targets]
-test = ["Test", "Adapt", "DataFrames", "InlineStrings", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets"]
+test = ["Test", "Adapt", "DataFrames", "InlineStrings", "Zygote",
+        "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets",
+        "CUDA", "cuDNN"]

ext/GraphNeuralNetworksCUDAExt/GNNGraphs/query.jl

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+
+GNNGraphs._rand_dense_vector(A::CUMAT_T) = CUDA.randn(size(A, 1))

ext/GraphNeuralNetworksCUDAExt/GNNGraphs/transform.jl

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+
+GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz)

ext/GraphNeuralNetworksCUDAExt/GNNGraphs/utils.jl

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+
+GNNGraphs.iscuarray(x::AnyCuArray) = true
+
+
+function sort_edge_index(u::AnyCuArray, v::AnyCuArray)
+    #TODO proper cuda friendly implementation
+    sort_edge_index(u |> Flux.cpu, v |> Flux.cpu) |> Flux.gpu
+end
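
`iscuarray` is a function barrier: it lets graph code in src/ ask "is this a CUDA array?" without referencing any CUDA type. For `GNNGraphs.iscuarray` to be extendable like this, the parent package must own a generic fallback, presumably along these lines (an assumption for illustration; the actual definition lives in one of the changed files not shown in this excerpt):

    # Assumed generic fallback on the GNNGraphs side (illustration only):
    iscuarray(x::AbstractArray) = false

The `sort_edge_index` method is an explicit stop-gap: as its #TODO notes, it round-trips through the CPU via `Flux.cpu`/`Flux.gpu`, trading speed for correctness until a CUDA-friendly sort is implemented.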

ext/GraphNeuralNetworksCUDAExt/GraphNeuralNetworksCUDAExt.jl

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+module GraphNeuralNetworksCUDAExt
+
+using CUDA
+using Random, Statistics, LinearAlgebra
+using GraphNeuralNetworks
+using GraphNeuralNetworks.GNNGraphs
+using GraphNeuralNetworks.GNNGraphs: COO_T, ADJMAT_T, SPARSE_T
+import GraphNeuralNetworks: propagate
+
+const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
+
+include("GNNGraphs/query.jl")
+include("GNNGraphs/transform.jl")
+include("GNNGraphs/utils.jl")
+include("msgpass.jl")
+
+end #module
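
Two details let the extension add methods to the parent package: `import GraphNeuralNetworks: propagate` allows `msgpass.jl` to define `propagate` methods unqualified, and definitions written with a qualified name (`GNNGraphs._rand_dense_vector`, `GNNGraphs.dense_zeros_like`, `GNNGraphs.iscuarray`) attach methods to functions GNNGraphs already owns. The mechanism in miniature (toy modules, not from this codebase):

    module Parent
    f(x::Int) = "generic"              # function owned by the parent package
    end

    module ParentExt                   # stands in for the package extension
    using ..Parent
    Parent.f(x::Float64) = "special"   # qualified name adds a method to Parent.f
    end

    Parent.f(1)      # "generic"
    Parent.f(1.0)    # "special"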

ext/GraphNeuralNetworksCUDAExt/msgpass.jl

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+
+###### PROPAGATE SPECIALIZATIONS ####################
+
+## COPY_XJ
+
+## avoid the fast path on gpu until we have better cuda support
+function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
+                   xi, xj::AnyCuMatrix, e)
+    propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
+end
+
+## E_MUL_XJ
+
+## avoid the fast path on gpu until we have better cuda support
+function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
+                   xi, xj::AnyCuMatrix, e::AbstractVector)
+    propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e)
+end
+
+## W_MUL_XJ
+
+## avoid the fast path on gpu until we have better cuda support
+function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
+                   xi, xj::AnyCuMatrix, e::Nothing)
+    propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
+end
+
+# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
+#     A = adjacency_matrix(g, weighted=false)
+#     D = compute_degree(A)
+#     return xj * A * D
+# end
+
+# # Zygote bug. Error with sparse matrix without nograd
+# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2)))
+
+# Flux.Zygote.@nograd compute_degree
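
These three methods exist only to opt out of an optimization: for known message functions such as `copy_xj`, `propagate` has fast-path methods that dispatch on `::typeof(copy_xj)`, and the comments above say those paths are avoided on the GPU until CUDA support improves. Wrapping the built-in in an anonymous function produces an argument whose type is no longer `typeof(copy_xj)`, so dispatch falls through to the generic gather/scatter implementation. The trick in two lines:

    # The closure has its own type, so methods specialized on ::typeof(copy_xj)
    # no longer match, and the generic propagate method is selected instead:
    f = (xi, xj, e) -> copy_xj(xi, xj, e)
    typeof(f) === typeof(copy_xj)    # false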

src/GNNGraphs/GNNGraphs.jl

Lines changed: 1 addition & 2 deletions
@@ -2,7 +2,6 @@ module GNNGraphs
 
 using SparseArrays
 using Functors: @functor
-using CUDA
 import Graphs
 using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
               has_self_loops, is_directed
@@ -15,7 +14,7 @@ import KrylovKit
 using ChainRulesCore
 using LinearAlgebra, Random, Statistics
 import MLUtils
-using MLUtils: getobs, numobs
+using MLUtils: getobs, numobs, ones_like, zeros_like
 import Functors
 
 include("chainrules.jl") # hacks for differentiability
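
Importing `ones_like`/`zeros_like` from MLUtils is what lets device-specific allocation leave this module: the helpers allocate "an array like this one", so the array family (CPU or GPU) comes from the prototype instead of being hard-coded. A quick sketch of the standard MLUtils behavior:

    using MLUtils: zeros_like

    a = rand(Float32, 3, 4)
    zeros_like(a)                     # 3×4 Matrix{Float32} of zeros
    zeros_like(a, Float64, (2, 2))    # same array family, new eltype and size
    # With a CuArray prototype the result is a CuArray; no CUDA import is needed here.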

src/GNNGraphs/abstracttypes.jl

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@ const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}
 const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}}
 const ADJMAT_T = AbstractMatrix
 const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
-const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
 
 const AVecI = AbstractVector{<:Integer}
 
src/GNNGraphs/gatherscatter.jl

Lines changed: 0 additions & 57 deletions
@@ -16,60 +16,3 @@ function _scatter(aggr,
     dstsize = (size(src)[1:(end - 1)]..., n)
     return NNlib.scatter(aggr, src, idx; dstsize)
 end
-
-## TO MOVE TO NNlib ######################################################
-
-### Considers the src a zero dimensional object.
-### Useful for implementing `StatsBase.counts`, `degree`, etc...
-### function NNlib.scatter!(op, dst::AbstractArray, src::Number, idx::AbstractArray)
-###     for k in CartesianIndices(idx)
-###         # dst_v = NNlib._view(dst, idx[k])
-###         # dst_v .= (op).(dst_v, src)
-###         dst[idx[k]] .= (op).(dst[idx[k]], src)
-###     end
-###     dst
-### end
-
-# 10 times faster than the generic version above.
-# All the speedup comes from not broadcasting `op`, i dunno why.
-# function NNlib.scatter!(op, dst::AbstractVector, src::Number, idx::AbstractVector{<:Integer})
-#     for i in idx
-#         dst[i] = op(dst[i], src)
-#     end
-# end
-
-## NNlib._view(X, k) = view(X, k...)
-## NNlib._view(X, k::Union{Integer, CartesianIndex}) = view(X, k)
-#
-## Considers src as a zero dimensional object to be scattered
-## function NNlib.scatter(op,
-##                        src::Tsrc,
-##                        idx::AbstractArray{Tidx,Nidx};
-##                        init = nothing, dstsize = nothing) where {Tsrc<:Number,Tidx,Nidx}
-##     dstsz = isnothing(dstsize) ? maximum_dims(idx) : dstsize
-##     dst = similar(src, Tsrc, dstsz)
-##     xinit = isnothing(init) ? scatter_empty(op, Tsrc) : init
-##     fill!(dst, xinit)
-##     scatter!(op, dst, src, idx)
-## end
-
-# function scatter_scalar_kernel!(op, dst, src, idx)
-#     index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
-
-#     @inbounds if index <= length(idx)
-#         CUDA.@atomic dst[idx[index]...] = op(dst[idx[index]...], src)
-#     end
-#     return nothing
-# end
-
-# function NNlib.scatter!(op, dst::AnyCuArray, src::Number, idx::AnyCuArray)
-#     max_idx = length(idx)
-#     args = op, dst, src, idx
-
-#     kernel = @cuda launch=false scatter_scalar_kernel!(args...)
-#     config = launch_configuration(kernel.fun; max_threads=256)
-#     threads = min(max_idx, config.threads)
-#     blocks = cld(max_idx, threads)
-#     kernel(args...; threads=threads, blocks=blocks)
-#     return dst
-# end

src/GNNGraphs/query.jl

Lines changed: 1 addition & 2 deletions
@@ -181,7 +181,7 @@ If `weighted=true`, the `A` will contain the edge weights if any, otherwise the
 """
 function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType = eltype(g); dir = :out,
                                  weighted = true)
-    if g.graph[1] isa CuVector
+    if iscuarray(g.graph[1])
         # Revisit after
         # https://github.com/JuliaGPU/CUDA.jl/issues/1113
         A, n, m = to_dense(g.graph, T; num_nodes = g.num_nodes, weighted)
@@ -448,7 +448,6 @@ function _eigmax(A)
 end
 
 _rand_dense_vector(A::AbstractMatrix{T}) where {T} = randn(float(T), size(A, 1))
-_rand_dense_vector(A::CUMAT_T) = CUDA.randn(size(A, 1))
 
 # Eigenvalues for cuarray don't seem to be well supported.
 # https://github.com/JuliaGPU/CUDA.jl/issues/154
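
The deleted GPU method is the same one added in the extension's GNNGraphs/query.jl above. A device-specific method is needed because the generic fallback allocates a CPU vector, which would not live on the same device as a GPU matrix in the subsequent products. Roughly (a sketch that assumes CUDA is loaded; not code from this diff):

    A = CUDA.rand(Float32, 4, 4)
    v_cpu = randn(Float32, size(A, 1))   # CPU Vector, mismatched with A
    v_gpu = CUDA.randn(size(A, 1))       # CuVector on the same device, so A * v_gpu works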
