Skip to content

Commit 0221593

Browse files
authored
CUDA copy_xj propagate sparse support and benchmarks (#605)
* Add propagate copy_xj CUDA sparse support using matrix multiplication * Add benchmarks for CUDA sparse propagate copy_xj * Keep the old gather/scatter implementation for COO_T
1 parent c707e2e commit 0221593

File tree

4 files changed

+54
-2
lines changed

4 files changed

+54
-2
lines changed

GNNlib/ext/GNNlibCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using GNNGraphs: GNNGraph, COO_T, SPARSE_T
1010
## COPY_XJ
1111

1212
## avoid the fast path on gpu until we have better cuda support
13-
# CUDA fallback for `propagate(copy_xj, g, +)`: wrapping `copy_xj` in an
# anonymous function defeats dispatch on `::typeof(copy_xj)`, so the generic
# gather/scatter path is used instead of the spmm fast path (see the comment
# above: "avoid the fast path on gpu until we have better cuda support").
# Diff: the signature is narrowed from `Union{COO_T, SPARSE_T}` to `COO_T`
# only — presumably so SPARSE_T CUDA graphs now take the generic
# `xj * adjacency_matrix` fast path added in GNNlib/src/msgpass.jl; confirm
# against that method.
function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
13+
function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{COO_T}, ::typeof(+),
1414
xi, xj::AnyCuMatrix, e)
1515
# The lambda is intentional: it must NOT be `copy_xj` itself, or this very
# method (or the fast path) would be selected again by dispatch.
propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
1616
end

GNNlib/src/msgpass.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ end
213213
## COPY_XJ
214214

215215
# Fast path for `propagate(copy_xj, g, +)`: sum-aggregating each node's
# incoming neighbor features is a sparse matrix–matrix product with the
# (unweighted) adjacency matrix.
# `xi` and `e` are accepted for interface uniformity but unused here.
function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
216-
A = adjacency_matrix(g, weighted = false)
216+
# Diff: request the adjacency matrix with `eltype(xj)` so A's element type
# matches the feature matrix — presumably required for the CUDA sparse
# spmm kernels to dispatch/avoid promotion; confirm against CUSPARSE support.
A = adjacency_matrix(g, eltype(xj); weighted = false)
217217
return xj * A
218218
end
219219

GraphNeuralNetworks/perf/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
[deps]
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
33
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
4+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
5+
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
6+
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
47
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
8+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
59
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
610
Graphs = "093fc24a-ae57-5d10-9952-331d41423f4d" # NOTE(review): duplicate [deps] key — `Graphs` is already listed above with UUID "86223c79-3864-5bf0-83f7-82e725a168b6" (the actual Graphs.jl UUID); "093fc24a-…" belongs to LightGraphs.jl. Duplicate keys are a TOML parse error; this stale entry should be removed.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# # Activate the perf environment
2+
# using Pkg
3+
# Pkg.activate(@__DIR__)
4+
# Pkg.develop(path=joinpath(@__DIR__, "..", "..", "GNNGraphs"))
5+
# Pkg.develop(path=joinpath(@__DIR__, "..", "..", "GNNlib"))
6+
# Pkg.develop(path=joinpath(@__DIR__, ".."))
7+
# Pkg.instantiate()
8+
using SparseArrays
9+
using GraphNeuralNetworks
10+
using BenchmarkTools
11+
import Random: seed!
12+
using LinearAlgebra
13+
using Flux, CUDA
14+
15+
# ENV["JULIA_DEBUG"] = "GraphNeuralNetworks,GNNlib,GNNlibCUDAExt,GNNGraphs,GNNGraphsCUDAExt,CUDA" # packages with debugging enabled, don't put a whitespace between the package names
16+
17+
"""
    prop_copy_xj(graph_type, sp_p, n, feat_size; dev = gpu_device())

Benchmark `propagate(copy_xj, g, +)` on a random `n`×`n` graph with edge
density `sp_p` and `feat_size`-dimensional node features, built with the
given `graph_type` (e.g. `:sparse` or `:coo`) and moved to `dev`.

Two variants are timed with `@btime`:
1. `propagate(copy_xj, …)` — dispatches to the fast path (spmm for `:sparse`).
2. `propagate((xi, xj, e) -> xj, …)` — the anonymous function defeats the
   `::typeof(copy_xj)` dispatch, forcing the generic gather/scatter path.

Prints the timings and returns `nothing`.

Fix vs. original: `dev` was read from a non-`const` global defined *after*
this function in the script; it is now a keyword argument defaulting to
`gpu_device()`, which is backward-compatible for existing callers.
"""
function prop_copy_xj(graph_type, sp_p, n, feat_size; dev = gpu_device())
    A = sprand(n, n, sp_p)    # random sparse adjacency, density sp_p
    b = rand(1, n)            # scalar per-node feature
    B = rand(feat_size, n)    # feat_size × n node feature matrix
    g = GNNGraph(A,
                 ndata = (; b = b, B = B),
                 edata = (; A = reshape(A.nzval, 1, :)),
                 graph_type = graph_type) |> dev
    printstyled("propagate copy_xj for graph type: $graph_type", "\n", color = :yellow)
    # Warm-up run so compilation is excluded from the benchmark below.
    CUDA.@sync propagate(copy_xj, g, +; xj = g.ndata.B)
    @btime CUDA.@sync propagate($copy_xj, $g, +; xj = $g.ndata.B) # spmm fast path for :sparse
    printstyled("gather/scatter propagate copy_xj for graph type: $graph_type", "\n", color = :yellow)
    # The lambda bypasses the copy_xj fast path, exercising gather/scatter.
    CUDA.@sync propagate((xi, xj, e) -> xj, g, +; xj = g.ndata.B) # warm-up
    @btime CUDA.@sync propagate((xi, xj, e) -> xj, $g, +; xj = $g.ndata.B)
    return nothing
end
36+
37+
# Deterministic inputs for reproducible benchmark numbers.
seed!(0)
dev = gpu_device()
println("Device: ", dev)

feat_size = 128

# Sweep graph size and sparsity for the :sparse graph_type, printing a
# header before each benchmark run.
for n in (32, 128, 1024), sp_p in (0.01, 0.1, 0.9)
    printstyled("n = $n, feat_size = $feat_size, sparsity = $sp_p\n", color = :blue)
    prop_copy_xj(:sparse, sp_p, n, feat_size)
    println()
end

0 commit comments

Comments
 (0)