Skip to content

Commit ff7d520

Browse files
author
Katharine Hyatt
committed
Comments
1 parent 3f08b07 commit ff7d520

File tree

5 files changed

+58
-138
lines changed

5 files changed

+58
-138
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,7 @@ function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.left_null!), US::Tu
167167
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
168168
ind = MatrixAlgebraKit.findtruncated(extended_S, strategy)
169169
trunc_cols = collect(1:size(U, 2))[ind]
170-
Utrunc = similar(U, (size(U, 1), length(trunc_cols)))
171-
Utrunc .= U[:, trunc_cols]
170+
Utrunc = U[:, trunc_cols]
172171
return Utrunc, ind
173172
end
174173

test/amd/orthnull.jl

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,50 +7,8 @@ using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_a
77
initialize_output, AbstractAlgorithm
88
using AMDGPU
99

10-
# Used to test non-AbstractMatrix codepaths.
11-
struct LinearMap{P <: AbstractMatrix}
12-
parent::P
13-
end
14-
Base.parent(A::LinearMap) = getfield(A, :parent)
15-
function Base.copy!(dest::LinearMap, src::LinearMap)
16-
copy!(parent(dest), parent(src))
17-
return dest
18-
end
19-
function LinearAlgebra.mul!(C::LinearMap, A::LinearMap, B::LinearMap)
20-
mul!(parent(C), parent(A), parent(B))
21-
return C
22-
end
23-
24-
function MatrixAlgebraKit.copy_input(::typeof(qr_compact), A::LinearMap)
25-
return LinearMap(copy_input(qr_compact, parent(A)))
26-
end
27-
function MatrixAlgebraKit.copy_input(::typeof(lq_compact), A::LinearMap)
28-
return LinearMap(copy_input(lq_compact, parent(A)))
29-
end
30-
function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), A::LinearMap)
31-
return LinearMap.(initialize_output(left_orth!, parent(A)))
32-
end
33-
function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap)
34-
return LinearMap.(initialize_output(right_orth!, parent(A)))
35-
end
36-
function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC, alg::AbstractAlgorithm)
37-
return check_input(left_orth!, parent(A), parent.(VC), alg)
38-
end
39-
function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC, alg::AbstractAlgorithm)
40-
return check_input(right_orth!, parent(A), parent.(VC), alg)
41-
end
42-
function MatrixAlgebraKit.default_svd_algorithm(::Type{LinearMap{A}}; kwargs...) where {A}
43-
return default_svd_algorithm(A; kwargs...)
44-
end
45-
function MatrixAlgebraKit.initialize_output(
46-
::typeof(svd_compact!), A::LinearMap,
47-
alg::GPU_SVDAlgorithm
48-
)
49-
return LinearMap.(initialize_output(svd_compact!, parent(A), alg))
50-
end
51-
function MatrixAlgebraKit.svd_compact!(A::LinearMap, USVᴴ, alg::GPU_SVDAlgorithm)
52-
return LinearMap.(svd_compact!(parent(A), parent.(USVᴴ), alg))
53-
end
10+
# testing non-AbstractArray codepaths:
11+
include(joinpath("..", "linearmap.jl"))
5412

5513
@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
5614
rng = StableRNG(123)

test/cuda/orthnull.jl

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,50 +7,8 @@ using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_a
77
initialize_output, AbstractAlgorithm
88
using CUDA
99

10-
# Used to test non-AbstractMatrix codepaths.
11-
struct LinearMap{P <: AbstractMatrix}
12-
parent::P
13-
end
14-
Base.parent(A::LinearMap) = getfield(A, :parent)
15-
function Base.copy!(dest::LinearMap, src::LinearMap)
16-
copy!(parent(dest), parent(src))
17-
return dest
18-
end
19-
function LinearAlgebra.mul!(C::LinearMap, A::LinearMap, B::LinearMap)
20-
mul!(parent(C), parent(A), parent(B))
21-
return C
22-
end
23-
24-
function MatrixAlgebraKit.copy_input(::typeof(qr_compact), A::LinearMap)
25-
return LinearMap(copy_input(qr_compact, parent(A)))
26-
end
27-
function MatrixAlgebraKit.copy_input(::typeof(lq_compact), A::LinearMap)
28-
return LinearMap(copy_input(lq_compact, parent(A)))
29-
end
30-
function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), A::LinearMap)
31-
return LinearMap.(initialize_output(left_orth!, parent(A)))
32-
end
33-
function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap)
34-
return LinearMap.(initialize_output(right_orth!, parent(A)))
35-
end
36-
function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC, alg::AbstractAlgorithm)
37-
return check_input(left_orth!, parent(A), parent.(VC), alg)
38-
end
39-
function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC, alg::AbstractAlgorithm)
40-
return check_input(right_orth!, parent(A), parent.(VC), alg)
41-
end
42-
function MatrixAlgebraKit.default_svd_algorithm(::Type{LinearMap{A}}; kwargs...) where {A}
43-
return default_svd_algorithm(A; kwargs...)
44-
end
45-
function MatrixAlgebraKit.initialize_output(
46-
::typeof(svd_compact!), A::LinearMap,
47-
alg::GPU_SVDAlgorithm
48-
)
49-
return LinearMap.(initialize_output(svd_compact!, parent(A), alg))
50-
end
51-
function MatrixAlgebraKit.svd_compact!(A::LinearMap, USVᴴ, alg::GPU_SVDAlgorithm)
52-
return LinearMap.(svd_compact!(parent(A), parent.(USVᴴ), alg))
53-
end
10+
# testing non-AbstractArray codepaths:
11+
include(joinpath("..", "linearmap.jl"))
5412

5513
@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
5614
rng = StableRNG(123)

test/linearmap.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
module LinearMaps
2+
3+
export LinearMap
4+
5+
using MatrixAlgebraKit
6+
using MatrixAlgebraKit: AbstractAlgorithm
7+
import MatrixAlgebraKit as MAK
8+
9+
using LinearAlgebra: LinearAlgebra, lmul!, rmul!
10+
11+
# Used to test non-AbstractMatrix codepaths.
12+
struct LinearMap{P <: AbstractMatrix}
13+
parent::P
14+
end
15+
Base.parent(A::LinearMap) = A.parent
16+
17+
Base.copy!(dest::LinearMap, src::LinearMap) = (copy!(parent(dest), parent(src)); dest)
18+
19+
# necessary for orth_svd default implementations
20+
LinearAlgebra.lmul!(D::LinearMap, A::LinearMap) = (lmul!(parent(D), parent(A)); A)
21+
LinearAlgebra.rmul!(A::LinearMap, D::LinearMap) = (rmul!(parent(A), parent(D)); A)
22+
LinearAlgebra.mul!(C::LinearMap, A::LinearMap, B::LinearMap, α, β) = LinearAlgebra.mul!(parent(C), parent(A), parent(B), α, β)
23+
24+
for f in (:qr_compact, :lq_compact)
25+
@eval MAK.copy_input(::typeof($f), A::LinearMap) = LinearMap(MAK.copy_input($f, parent(A)))
26+
end
27+
28+
for f! in (:qr_compact!, :lq_compact!, :svd_compact!, :svd_full!, :svd_trunc!)
29+
@eval MAK.check_input(::typeof($f!), A::LinearMap, F, alg::AbstractAlgorithm) =
30+
MAK.check_input($f!, parent(A), parent.(F), alg)
31+
@eval MAK.initialize_output(::typeof($f!), A::LinearMap, alg::AbstractAlgorithm) =
32+
LinearMap.(MAK.initialize_output($f!, parent(A), alg))
33+
@eval MAK.$f!(A::LinearMap, F, alg::AbstractAlgorithm) =
34+
LinearMap.(MAK.$f!(parent(A), parent.(F), alg))
35+
end
36+
37+
for f! in (:left_orth!, :right_orth!)
38+
@eval MAK.check_input(::typeof($f!), A::LinearMap, F, alg) =
39+
MAK.check_input($f!, parent(A), parent.(F), alg)
40+
@eval MAK.initialize_output(::typeof($f!), A::LinearMap) =
41+
LinearMap.(MAK.initialize_output($f!, parent(A)))
42+
end
43+
44+
for f in (:qr, :lq, :svd)
45+
default_f = Symbol(:default_, f, :_algorithm)
46+
@eval MAK.$default_f(::Type{LinearMap{A}}; kwargs...) where {A} = MAK.$default_f(A; kwargs...)
47+
end
48+
end
49+
50+
using .LinearMaps

test/orthnull.jl

Lines changed: 3 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,55 +6,10 @@ using LinearAlgebra: LinearAlgebra, I, mul!
66
using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
77
initialize_output, AbstractAlgorithm
88

9-
eltypes = (Float32, Float64, ComplexF32, ComplexF64)
10-
11-
# Used to test non-AbstractMatrix codepaths.
12-
struct LinearMap{P <: AbstractMatrix}
13-
parent::P
14-
end
15-
Base.parent(A::LinearMap) = getfield(A, :parent)
16-
function Base.copy!(dest::LinearMap, src::LinearMap)
17-
copy!(parent(dest), parent(src))
18-
return dest
19-
end
20-
function LinearAlgebra.mul!(C::LinearMap, A::LinearMap, B::LinearMap)
21-
mul!(parent(C), parent(A), parent(B))
22-
return C
23-
end
9+
# testing non-AbstractArray codepaths:
10+
include("linearmap.jl")
2411

25-
function MatrixAlgebraKit.copy_input(::typeof(qr_compact), A::LinearMap)
26-
return LinearMap(copy_input(qr_compact, parent(A)))
27-
end
28-
function MatrixAlgebraKit.copy_input(::typeof(lq_compact), A::LinearMap)
29-
return LinearMap(copy_input(lq_compact, parent(A)))
30-
end
31-
function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), A::LinearMap)
32-
return LinearMap.(initialize_output(left_orth!, parent(A)))
33-
end
34-
function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap)
35-
return LinearMap.(initialize_output(right_orth!, parent(A)))
36-
end
37-
function MatrixAlgebraKit.check_input(
38-
::typeof(left_orth!), A::LinearMap, VC, alg::AbstractAlgorithm
39-
)
40-
return check_input(left_orth!, parent(A), parent.(VC), alg)
41-
end
42-
function MatrixAlgebraKit.check_input(
43-
::typeof(right_orth!), A::LinearMap, VC, alg::AbstractAlgorithm
44-
)
45-
return check_input(right_orth!, parent(A), parent.(VC), alg)
46-
end
47-
function MatrixAlgebraKit.default_svd_algorithm(::Type{LinearMap{A}}; kwargs...) where {A}
48-
return default_svd_algorithm(A; kwargs...)
49-
end
50-
function MatrixAlgebraKit.initialize_output(
51-
::typeof(svd_compact!), A::LinearMap, alg::LAPACK_SVDAlgorithm
52-
)
53-
return LinearMap.(initialize_output(svd_compact!, parent(A), alg))
54-
end
55-
function MatrixAlgebraKit.svd_compact!(A::LinearMap, USVᴴ, alg::LAPACK_SVDAlgorithm)
56-
return LinearMap.(svd_compact!(parent(A), parent.(USVᴴ), alg))
57-
end
12+
eltypes = (Float32, Float64, ComplexF32, ComplexF64)
5813

5914
@testset "left_orth and left_null for T = $T" for T in eltypes
6015
rng = StableRNG(123)

0 commit comments

Comments
 (0)