diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index befb4e0b..0ca43183 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -167,9 +167,8 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix return C end -# TODO: intersect on GPU arrays is not working -MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::AbstractVector) = MatrixAlgebraKit._ind_intersect(collect(A), B) -MatrixAlgebraKit._ind_intersect(A::AbstractVector, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(A, collect(B)) -MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) +# TODO: intersect doesn't work on GPU +MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) = + MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 432f176a..8bb09db1 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -191,9 +191,8 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T return C end -# TODO: intersect on GPU arrays is not working -MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::AbstractVector) = MatrixAlgebraKit._ind_intersect(collect(A), B) -MatrixAlgebraKit._ind_intersect(A::AbstractVector, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(A, collect(B)) -MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) +# TODO: intersect doesn't work on GPU +MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = + MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) end diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index be730bed..945e0772 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -126,6 +126,15 @@ function _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector) end _ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A) _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B + +# when one of the ind selections is a unitrange, filter is more efficient than intersect +# since we know both selections only contain unique entries +# (This is also more GPU-friendly!) +_ind_intersect(A::AbstractUnitRange{Int}, B::AbstractUnitRange{Int}) = intersect(A, B) +_ind_intersect(A::AbstractVector{Int}, B::AbstractUnitRange{Int}) = filter(in(B), A) +_ind_intersect(A::AbstractUnitRange{Int}, B::AbstractVector{Int}) = _ind_intersect(B, A) + +# when all else fails, call intersect _ind_intersect(A, B) = intersect(A, B) # Truncation error