Skip to content

Commit 56f9c5e

Browse files
committed
Added propagate copy_xj CUDA sparse support using matrix mul
1 parent c707e2e commit 56f9c5e

File tree

2 files changed

+1
-9
lines changed

2 files changed

+1
-9
lines changed

GNNlib/ext/GNNlibCUDAExt.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,6 @@ using GNNGraphs: GNNGraph, COO_T, SPARSE_T
77

88
###### PROPAGATE SPECIALIZATIONS ####################
99

10-
## COPY_XJ
11-
12-
## avoid the fast path on gpu until we have better cuda support
13-
function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
14-
xi, xj::AnyCuMatrix, e)
15-
propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
16-
end
17-
1810
## E_MUL_XJ
1911

2012
## avoid the fast path on gpu until we have better cuda support

GNNlib/src/msgpass.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ end
213213
## COPY_XJ
214214

215215
function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
216-
A = adjacency_matrix(g, weighted = false)
216+
A = adjacency_matrix(g, eltype(xj); weighted = false)
217217
return xj * A
218218
end
219219

0 commit comments

Comments
 (0)