Skip to content

Commit 4c0c38e

Browse files
committed
Refactor CUDA binarize function to use Bool as the CPU version
1 parent df734e1 commit 4c0c38e

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

GNNGraphs/ext/GNNGraphsCUDAExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz
2020
# Utils
2121

2222
GNNGraphs.iscuarray(x::AnyCuArray) = true
23-
function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
24-
@debug "Binarizing CuSparseMatrixCSC with type $(Tv) and index type $(Ti)"
25-
bin_vals = fill!(similar(nonzeros(Mat)), one(Tv))
26-
# Binarize a CuSparseMatrixCSC by setting all nonzero values to one(Tv)
23+
24+
function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC)
25+
@debug "Binarizing CuSparseMatrixCSC"
26+
bin_vals = fill!(similar(nonzeros(Mat), Bool), true)
2727
return CUSPARSE.CuSparseMatrixCSC(Mat.colPtr, rowvals(Mat), bin_vals, size(Mat))
2828
end
2929
function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC, T::DataType)

0 commit comments

Comments
 (0)