Skip to content

Commit d8c2ecb

Browse files
Replace allowed_getindex with only() for GPU compatibility
Co-authored-by: albertomercurio <61953577+albertomercurio@users.noreply.github.com>
1 parent 5c36d6c commit d8c2ecb

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

src/matrix_coo/matrix_coo.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ function Base.:+(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
408408

409409
# Compute write indices using cumsum
410410
write_indices = _cumsum_AK(keep_mask)
411-
nnz_final = allowed_getindex(write_indices, nnz_concat)
411+
nnz_final = only(write_indices[nnz_concat:nnz_concat])
412412

413413
# Allocate final arrays
414414
rowind_C = similar(getrowind(A), nnz_final)
@@ -514,7 +514,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
514514
kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,))
515515

516516
write_indices = _cumsum_AK(keep_mask)
517-
nnz_final = allowed_getindex(write_indices, nnz_concat)
517+
nnz_final = only(write_indices[nnz_concat:nnz_concat])
518518

519519
rowind_C = similar(getrowind(_A), nnz_final)
520520
colind_C = similar(getcolind(_A), nnz_final)

src/matrix_csc/matrix_csc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ function Base.:+(A::DeviceSparseMatrixCSC, B::DeviceSparseMatrixCSC)
379379
colptr_C[1:1] .= one(Ti)
380380

381381
# Allocate result arrays
382-
nnz_total = allowed_getindex(colptr_C, n + 1) - one(Ti)
382+
nnz_total = only(colptr_C[n+1:n+1]) - one(Ti)
383383
rowval_C = similar(getrowval(A), nnz_total)
384384
nzval_C = similar(nonzeros(A), Tv, nnz_total)
385385

src/matrix_csr/matrix_csr.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ function Base.:+(A::DeviceSparseMatrixCSR, B::DeviceSparseMatrixCSR)
376376
rowptr_C[1:1] .= one(Ti)
377377

378378
# Allocate result arrays
379-
nnz_total = allowed_getindex(rowptr_C, m + 1) - one(Ti)
379+
nnz_total = only(rowptr_C[m+1:m+1]) - one(Ti)
380380
colval_C = similar(getcolval(A), nnz_total)
381381
nzval_C = similar(nonzeros(A), Tv, nnz_total)
382382

0 commit comments

Comments
 (0)