diff --git a/GNNGraphs/ext/GNNGraphsCUDAExt.jl b/GNNGraphs/ext/GNNGraphsCUDAExt.jl index 0d839f58c..af9e9f820 100644 --- a/GNNGraphs/ext/GNNGraphsCUDAExt.jl +++ b/GNNGraphs/ext/GNNGraphsCUDAExt.jl @@ -4,6 +4,7 @@ using CUDA using Random, Statistics, LinearAlgebra using GNNGraphs using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T +using SparseArrays const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix} @@ -20,6 +21,11 @@ GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz GNNGraphs.iscuarray(x::AnyCuArray) = true +function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC, T::DataType = Bool) + bin_vals = fill!(similar(nonzeros(Mat)), one(T)) + return CUSPARSE.CuSparseMatrixCSC(Mat.colPtr, rowvals(Mat), bin_vals, size(Mat)) +end + function sort_edge_index(u::AnyCuArray, v::AnyCuArray) dev = get_device(u) diff --git a/GNNGraphs/src/query.jl b/GNNGraphs/src/query.jl index 879aee9df..b427c5916 100644 --- a/GNNGraphs/src/query.jl +++ b/GNNGraphs/src/query.jl @@ -235,7 +235,7 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g @assert dir ∈ [:in, :out] A = g.graph if !weighted - A = binarize(A) + A = binarize(A, T) end A = T != eltype(A) ? T.(A) : A return dir == :out ? A : A' @@ -377,7 +377,7 @@ end function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num_nodes::Int) if edge_weight === false - A = binarize(A) + A = binarize(A, T) end A = eltype(A) != T ? T.(A) : A return dir == :out ? vec(sum(A, dims = 2)) : diff --git a/GNNGraphs/src/utils.jl b/GNNGraphs/src/utils.jl index 9c2b94057..9b12f43f1 100644 --- a/GNNGraphs/src/utils.jl +++ b/GNNGraphs/src/utils.jl @@ -295,7 +295,7 @@ function _rand_edges(rng, (n1, n2), m) return s, t, val end -binarize(x) = map(>(0), x) +binarize(x, T::DataType = Bool) = ifelse.(x .> 0, one(T), zero(T)) CRC.@non_differentiable binarize(x...) CRC.@non_differentiable edge_encoding(x...)