diff --git a/GNNlib/ext/GNNlibCUDAExt.jl b/GNNlib/ext/GNNlibCUDAExt.jl
index afe22c3f0..78ab49262 100644
--- a/GNNlib/ext/GNNlibCUDAExt.jl
+++ b/GNNlib/ext/GNNlibCUDAExt.jl
@@ -26,7 +26,7 @@ end
 ## W_MUL_XJ
 
 ## avoid the fast path on gpu until we have better cuda support
-function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
+function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{COO_T}, ::typeof(+),
                           xi, xj::AnyCuMatrix, e::Nothing)
     propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
 end
diff --git a/GNNlib/src/msgpass.jl b/GNNlib/src/msgpass.jl
index 7b7685e1b..29aaaa772 100644
--- a/GNNlib/src/msgpass.jl
+++ b/GNNlib/src/msgpass.jl
@@ -233,7 +233,7 @@ end
 
 # for weighted convolution
 function propagate(::typeof(w_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix,
                    e::Nothing)
-    A = adjacency_matrix(g, weighted = true)
+    A = adjacency_matrix(g, eltype(xj); weighted = true)
     return xj * A
 end
diff --git a/GraphNeuralNetworks/perf/Project.toml b/GraphNeuralNetworks/perf/Project.toml
index c09a51049..39f4e96fc 100644
--- a/GraphNeuralNetworks/perf/Project.toml
+++ b/GraphNeuralNetworks/perf/Project.toml
@@ -7,4 +7,3 @@ GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
 GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
-Graphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
diff --git a/GraphNeuralNetworks/perf/sparse_propagate_cuda.jl b/GraphNeuralNetworks/perf/sparse_propagate_cuda.jl
index fee5372d6..e29e82596 100644
--- a/GraphNeuralNetworks/perf/sparse_propagate_cuda.jl
+++ b/GraphNeuralNetworks/perf/sparse_propagate_cuda.jl
@@ -34,6 +34,23 @@ function prop_copy_xj(graph_type, sp_p, n, feat_size)
     return nothing
 end
 
+function prop_w_mul_xj(graph_type, sp_p, n, feat_size)
+    A = sprand(n, n, sp_p)
+    b = rand(1, n)
+    B = rand(feat_size, n)
+    g = GNNGraph(A,
+                 ndata = (; b = b, B = B),
+                 edata = (; A = reshape(A.nzval, 1, :)),
+                 graph_type = graph_type) |> dev
+    printstyled("propagate w_mul_xj for graph type: $graph_type", "\n", color=:yellow)
+    CUDA.@sync propagate(w_mul_xj, g, +; xj = g.ndata.B) # run once to compile before benchmarking
+    @btime CUDA.@sync propagate($w_mul_xj, $g, +; xj = $g.ndata.B) # using spmm for :sparse
+    printstyled("gather/scatter propagate w_mul_xj for graph type: $graph_type", "\n", color=:yellow)
+    CUDA.@sync propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +; xj = g.ndata.B) # run once to compile before benchmarking
+    @btime CUDA.@sync propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), $g, +; xj = $g.ndata.B) # using gather/scatter
+    return nothing
+end
+
 seed!(0)
 dev = gpu_device()
 println("Device: ", dev)
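For reviewers, a minimal CPU sketch of what the `msgpass.jl` hunk changes. It assumes only the public API used in the patch (`GNNGraph`, `adjacency_matrix`, `propagate`, `w_mul_xj`); the graph and feature values are made up for illustration. The fast path computes `xj * A`, so building `A` with `eltype(xj)` keeps `Float32` features from being promoted by a default-eltype weighted adjacency:

```julia
using GraphNeuralNetworks  # re-exports GNNlib's propagate and w_mul_xj

# Tiny weighted graph in COO form: 3 nodes, 3 edges with Float32 weights.
s, t = [1, 2, 3], [2, 3, 1]
w = Float32[0.5, 1.0, 2.0]
g = GNNGraph(s, t, w)

X = rand(Float32, 4, 3)  # feat_size × num_nodes node features

# The fast path is literally `xj * A`. Requesting eltype(X), as the patch
# does, yields a Float32 adjacency, so the product stays Float32 instead of
# being promoted by a Float64 weight matrix.
A = adjacency_matrix(g, eltype(X); weighted = true)
y = X * A
@assert eltype(y) == Float32
@assert y ≈ propagate(w_mul_xj, g, +; xj = X)  # dispatches to the same fast path on CPU
```

On GPU, the `GNNlibCUDAExt` specialization above still routes COO graphs through the gather/scatter fallback; the new `prop_w_mul_xj` benchmark compares that fallback against the spmm fast path used for `:sparse` graphs.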