Commit df734e1

Add binarize function for CuSparseMatrixCSC and update its usage
1 parent b85f098 commit df734e1

File tree

5 files changed: +19 -11 lines

GNNGraphs/ext/GNNGraphsCUDAExt.jl

Lines changed: 13 additions & 0 deletions

@@ -4,6 +4,7 @@ using CUDA
 using Random, Statistics, LinearAlgebra
 using GNNGraphs
 using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T
+using SparseArrays

 const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}

@@ -19,6 +20,18 @@ GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz
 # Utils

 GNNGraphs.iscuarray(x::AnyCuArray) = true
+function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
+    @debug "Binarizing CuSparseMatrixCSC with type $(Tv) and index type $(Ti)"
+    bin_vals = fill!(similar(nonzeros(Mat)), one(Tv))
+    # Binarize a CuSparseMatrixCSC by setting all nonzero values to one(Tv)
+    return CUSPARSE.CuSparseMatrixCSC(Mat.colPtr, rowvals(Mat), bin_vals, size(Mat))
+end
+function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC, T::DataType)
+    @debug "Binarizing CuSparseMatrixCSC with type $(T)"
+    bin_vals = fill!(similar(nonzeros(Mat)), one(T))
+    # Binarize a CuSparseMatrixCSC by setting all nonzero values to one(T)
+    return CUSPARSE.CuSparseMatrixCSC(Mat.colPtr, rowvals(Mat), bin_vals, size(Mat))
+end


 function sort_edge_index(u::AnyCuArray, v::AnyCuArray)
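The two added methods build the binarized matrix directly on the GPU: they reuse the existing sparsity structure (Mat.colPtr and rowvals(Mat)) and only allocate a fresh nonzero buffer filled with ones, so nothing is copied back to the CPU. A minimal usage sketch, assuming a CUDA-capable device (the matrix and its size are illustrative, not part of the commit):

using CUDA, SparseArrays, GNNGraphs

A  = sprand(Float32, 10, 10, 0.2)              # random sparse matrix built on the CPU
dA = CUDA.CUSPARSE.CuSparseMatrixCSC(A)        # move it to the GPU in CSC format
B  = GNNGraphs.binarize(dA)                    # same sparsity pattern, nonzeros set to one(Float32)
B2 = GNNGraphs.binarize(dA, Float32)           # variant used when a target element type is requested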

GNNGraphs/src/query.jl

Lines changed: 2 additions & 2 deletions

@@ -235,7 +235,7 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g
     @assert dir ∈ [:in, :out]
     A = g.graph
     if !weighted
-        A = binarize(A)
+        A = binarize(A, T)
     end
     A = T != eltype(A) ? T.(A) : A
     return dir == :out ? A : A'

@@ -377,7 +377,7 @@ end

 function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num_nodes::Int)
     if edge_weight === false
-        A = binarize(A)
+        A = binarize(A, T)
     end
     A = eltype(A) != T ? T.(A) : A
     return dir == :out ? vec(sum(A, dims = 2)) :
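Both call sites now forward the requested element type T to binarize, so a backend that overloads the two-argument method (such as the CUDA extension above) can return a matrix whose storage already suits T, and the T.(A) conversion that follows stays cheap. A small sketch of the public call path, assuming a toy random graph:

using GNNGraphs

g = rand_graph(6, 10)                                 # 6 nodes, 10 edges
A = adjacency_matrix(g, Float32; weighted = false)    # entries are zero(Float32) or one(Float32)
d = degree(g, Float32; edge_weight = false)           # unweighted degrees from the binarized matrix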

GNNGraphs/src/utils.jl

Lines changed: 2 additions & 0 deletions

@@ -296,6 +296,8 @@ function _rand_edges(rng, (n1, n2), m)
 end

 binarize(x) = map(>(0), x)
+# here just to allow CUDA extension to overload this function with correct type casting
+binarize(x, T::DataType) = binarize(x)

 CRC.@non_differentiable binarize(x...)
 CRC.@non_differentiable edge_encoding(x...)
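On the CPU the new two-argument method simply ignores T and falls back to the boolean map; the extra argument exists only so that package extensions can dispatch on it and handle the element type themselves. A quick sketch of the fallback behaviour, with an illustrative sparse matrix:

using SparseArrays
using GNNGraphs: binarize

M = sparse([1, 2], [2, 3], [0.5, 2.0], 3, 3)
binarize(M)              # Bool-valued sparse matrix, i.e. map(>(0), M)
binarize(M, Float32)     # same result on the CPU; the type is only used by extensions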

GNNlib/ext/GNNlibCUDAExt.jl

Lines changed: 0 additions & 8 deletions

@@ -7,14 +7,6 @@ using GNNGraphs: GNNGraph, COO_T, SPARSE_T

 ###### PROPAGATE SPECIALIZATIONS ####################

-## COPY_XJ
-
-## avoid the fast path on gpu until we have better cuda support
-function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
-                          xi, xj::AnyCuMatrix, e)
-    propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
-end
-
 ## E_MUL_XJ

 ## avoid the fast path on gpu until we have better cuda support

GNNlib/src/msgpass.jl

Lines changed: 2 additions & 1 deletion

@@ -213,7 +213,8 @@ end
 ## COPY_XJ

 function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
-    A = adjacency_matrix(g, weighted = false)
+    @debug "copy_xj: propagating with type $(typeof(xj))"
+    A = adjacency_matrix(g, eltype(xj); weighted = false)
     return xj * A
 end
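Passing eltype(xj) keeps the adjacency matrix in the same element type as the node features, so the xj * A product needs no promotion; together with the typed binarize above, this is presumably what makes the GPU-specific copy_xj override in GNNlibCUDAExt.jl unnecessary. A minimal sketch of the call this method specializes, assuming a toy graph and random features:

using GNNGraphs, GNNlib

g  = rand_graph(4, 8)          # 4 nodes, 8 edges
xj = rand(Float32, 3, 4)       # 3 features per node, one column per node
m  = GNNlib.propagate(GNNlib.copy_xj, g, +, nothing, xj, nothing)
# equivalent to xj * adjacency_matrix(g, Float32; weighted = false)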
