Skip to content

Commit 40e1ca8

Browse files
committed
Use existing implementation for column permute
1 parent 13d2bf6 commit 40e1ca8

File tree

3 files changed

+3
-17
lines changed

3 files changed

+3
-17
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,4 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix
206206
return C
207207
end
208208
209-
function MatrixAlgebraKit.permute_V_cols!(V, I::ROCVector{Int})
210-
I_ixs = ROCArray(collect(1:size(V, 1)))
211-
c_ixs = map(CartesianIndex, I, I_ixs)
212-
V[c_ixs] .= one(eltype(V))
213-
return V
214-
end
215-
216209
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,4 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T
191191
return C
192192
end
193193
194-
function MatrixAlgebraKit.permute_V_cols!(V, I::CuVector{Int})
195-
I_ixs = CuArray(collect(1:size(V, 1)))
196-
c_ixs = map(CartesianIndex, I, I_ixs)
197-
V[c_ixs] .= one(eltype(V))
198-
return V
199-
end
200-
201194
end

src/implementations/eigh.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,6 @@ function eigh_trunc_no_error!(A, DV, alg::TruncatedAlgorithm)
141141
return DVtrunc
142142
end
143143

144-
permute_V_cols!(V, I::Vector{Int}) = Base.permutecols!!(V, I)
145-
146144
# Diagonal logic
147145
# --------------
148146
function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
@@ -155,7 +153,9 @@ function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
155153
diagview(D) .= real.(diagview(A))[I]
156154
end
157155
zero!(V)
158-
V = permute_V_cols!(V, I)
156+
n = size(A, 1)
157+
I .+= (0:(n - 1)) .* n
158+
V[I] .= Ref(one(eltype(V)))
159159
return D, V
160160
end
161161

0 commit comments

Comments
 (0)