Skip to content

Commit adc6cf9

Browse files
committed
Unify binarize function signature
1 parent 7e6694d commit adc6cf9

File tree

2 files changed

+3
-9
lines changed

2 files changed

+3
-9
lines changed

GNNGraphs/ext/GNNGraphsCUDAExt.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,9 @@ GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz
2121

2222
GNNGraphs.iscuarray(x::AnyCuArray) = true
2323

24-
function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC)
25-
bin_vals = fill!(similar(nonzeros(Mat), Bool), true)
26-
return CUSPARSE.CuSparseMatrixCSC(Mat.colPtr, rowvals(Mat), bin_vals, size(Mat))
27-
end
28-
function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC, T::DataType)
24+
function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC, T::DataType = Bool)
25+
@debug "Binarizing sparse matrix of type $(typeof(Mat)) to type $(T)"
2926
bin_vals = fill!(similar(nonzeros(Mat)), one(T))
30-
# Binarize a CuSparseMatrixCSC by setting all nonzero values to one(T)
3127
return CUSPARSE.CuSparseMatrixCSC(Mat.colPtr, rowvals(Mat), bin_vals, size(Mat))
3228
end
3329

GNNGraphs/src/utils.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,7 @@ function _rand_edges(rng, (n1, n2), m)
295295
return s, t, val
296296
end
297297

298-
binarize(x) = map(>(0), x)
299-
# here just to allow CUDA extension to overload this function with correct type casting
300-
binarize(x, T::DataType) = binarize(x)
298+
binarize(x, T::DataType = Bool) = ifelse.(x .> 0, one(T), zero(T))
301299

302300
CRC.@non_differentiable binarize(x...)
303301
CRC.@non_differentiable edge_encoding(x...)

0 commit comments

Comments
 (0)