Skip to content

Commit c13ceca

Browse files
committed
Try to do SVD truncation on GPU with _ind_intersect
1 parent e1ea618 commit c13ceca

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,14 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix
167167
return C
168168
end
169169
170-
# TODO: intersect on GPU arrays is not working
171-
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::AbstractVector) = MatrixAlgebraKit._ind_intersect(collect(A), B)
172-
MatrixAlgebraKit._ind_intersect(A::AbstractVector, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(A, collect(B))
170+
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B) = MatrixAlgebraKit._ind_intersect(B, A)
171+
function MatrixAlgebraKit._ind_intersect(A::UnitRange, B::ROCVector{Int})
172+
sortedB = sort(B)
173+
firstB = findfirst(≥(first(A)), B)
174+
lastB = findlast(≤(last(A)), B)
175+
# ONLY works if the indices in B are contiguous!!!
176+
return B[firstB:lastB]
177+
end
173178
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
174179
175180
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,14 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T
191191
return C
192192
end
193193
194-
# TODO: intersect on GPU arrays is not working
195-
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::AbstractVector) = MatrixAlgebraKit._ind_intersect(collect(A), B)
196-
MatrixAlgebraKit._ind_intersect(A::AbstractVector, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(A, collect(B))
194+
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B) = MatrixAlgebraKit._ind_intersect(B, A)
195+
function MatrixAlgebraKit._ind_intersect(A::UnitRange, B::CuVector{Int})
196+
sortedB = sort(B)
197+
firstB = findfirst(≥(first(A)), B)
198+
lastB = findlast(≤(last(A)), B)
199+
# ONLY works if the indices in B are contiguous!!!
200+
return B[firstB:lastB]
201+
end
197202
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
198203
199204
end

0 commit comments

Comments
 (0)