Skip to content

Commit 3f08b07

Browse files
author
Katharine Hyatt
committed
Only use the scalar method for AMDGPU
1 parent 6e83d2d commit 3f08b07

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,4 +161,15 @@ function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
161161
return A, B
162162
end
163163

164+
function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.left_null!), US::Tuple{TU, TS}, strategy::MatrixAlgebraKit.TruncationStrategy) where {TU <: ROCArray, TS}
165+
# TODO: avoid allocation?
166+
U, S = US
167+
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
168+
ind = MatrixAlgebraKit.findtruncated(extended_S, strategy)
169+
trunc_cols = collect(1:size(U, 2))[ind]
170+
Utrunc = similar(U, (size(U, 1), length(trunc_cols)))
171+
Utrunc .= U[:, trunc_cols]
172+
return Utrunc, ind
173+
end
174+
164175
end

src/implementations/truncation.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@ function truncate(::typeof(left_null!), (U, S), strategy::TruncationStrategy)
1717
# TODO: avoid allocation?
1818
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
1919
ind = findtruncated(extended_S, strategy)
20-
trunc_cols = collect(1:size(U, 2))[ind]
21-
Utrunc = similar(U, (size(U, 1), length(trunc_cols)))
22-
Utrunc .= U[:, trunc_cols]
23-
return Utrunc, ind
20+
return U[:, ind], ind
2421
end
2522
function truncate(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy)
2623
# TODO: avoid allocation?

0 commit comments

Comments
 (0)