diff --git a/Project.toml b/Project.toml index 29e0e8a..8da5421 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.5" +version = "0.1.6" [deps] DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 06fd869..3ae3fad 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -579,13 +579,17 @@ end using MatrixAlgebraKit: copy_input, eig_full, + eig_vals, eigh_full, + eigh_vals, qr_compact, qr_full, + left_null, left_orth, left_polar, lq_compact, lq_full, + right_null, right_orth, right_polar, svd_compact, @@ -741,13 +745,17 @@ _copy_input_squareeye(f::F, a::SquareEye) where {F} = a for f in [ :eig_full, + :eig_vals, :eigh_full, + :eigh_vals, :qr_compact, :qr_full, + :left_null, :left_orth, :left_polar, :lq_compact, :lq_full, + :right_null, :right_orth, :right_polar, :svd_compact, @@ -791,6 +799,12 @@ end _initialize_output_squareeye(f::F, a) where {F} = initialize_output(f, a) _initialize_output_squareeye(f::F, a, alg) where {F} = initialize_output(f, a, alg) +for f in [:left_null!, :right_null!] + @eval begin + _initialize_output_squareeye(::typeof($f), a::SquareEye) = a + _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = a + end +end for f in [ :eig_full!, :eigh_full!, @@ -857,4 +871,74 @@ for f in [ end end +for f in [:left_null!, :right_null!] + f′ = Symbol("_", f, "_squareeye") + @eval begin + $f′(a, F; kwargs...) = $f(a, F; kwargs...) + $f′(a::SquareEye, F) = F + end + for T in [:SquareEyeKronecker, :KroneckerSquareEye] + @eval begin + function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T) + return _initialize_output_squareeye($f, a.a) ⊗ _initialize_output_squareeye($f, a.b) + end + function MatrixAlgebraKit.$f(a::$T, F; kwargs1=(;), kwargs2=(;), kwargs...) + $f′(a.a, F.a; kwargs..., kwargs1...) + $f′(a.b, F.b; kwargs..., kwargs2...) + return F + end + end + end +end + +function MatrixAlgebraKit.initialize_output(f::typeof(left_null!), a::SquareEyeSquareEye) + return _initialize_output_squareeye(f, a.a) ⊗ _initialize_output_squareeye(f, a.b) +end +function MatrixAlgebraKit.left_null!( + a::SquareEyeSquareEye, F; kwargs1=(;), kwargs2=(;), kwargs... +) + return throw(MethodError(left_null!, (a, F))) +end + +function MatrixAlgebraKit.initialize_output(f::typeof(right_null!), a::SquareEyeSquareEye) + return _initialize_output_squareeye(f, a.a) ⊗ _initialize_output_squareeye(f, a.b) +end +function MatrixAlgebraKit.right_null!( + a::SquareEyeSquareEye, F; kwargs1=(;), kwargs2=(;), kwargs... +) + return throw(MethodError(right_null!, (a, F))) +end + +for f in [:eig_vals!, :eigh_vals!, :svd_vals!] + @eval begin + _initialize_output_squareeye(::typeof($f), a::SquareEye) = parent(a) + _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = parent(a) + end +end + +for f in [:eig_vals!, :eigh_vals!, :svd_vals!] + f′ = Symbol("_", f, "_squareeye") + @eval begin + $f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...) + $f′(a, F, alg::SquareEyeAlgorithm) = F + end + for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye] + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), a::$T, alg::KroneckerAlgorithm + ) + return _initialize_output_squareeye($f, a.a, alg.a) ⊗ + _initialize_output_squareeye($f, a.b, alg.b) + end + function MatrixAlgebraKit.$f( + a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs... + ) + $f′(a.a, F.a, alg.a; kwargs..., kwargs1...) + $f′(a.b, F.b, alg.b; kwargs..., kwargs2...) + return F + end + end + end +end + end diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index a26508d..82943cd 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -1,4 +1,4 @@ -using FillArrays: Eye +using FillArrays: Eye, Ones using KroneckerArrays: ⊗, arguments using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm using MatrixAlgebraKit: @@ -124,13 +124,8 @@ end # TODO: # eig_trunc - # eig_vals # eigh_trunc - # eigh_vals - # left_null - # right_null # svd_trunc - # svd_vals for f in (eig_full, eigh_full) a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) @@ -154,17 +149,39 @@ end @test arguments(v, 2) isa Eye end + for f in (eig_vals, eigh_vals) + a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) + d = @constinferred f(a) + d′ = f(Matrix(a)) + @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) + @test arguments(d, 1) isa Ones + @test arguments(d, 2) ≈ f(arguments(a, 2)) + + a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3) + d = @constinferred f(a) + d′ = f(Matrix(a)) + @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) + @test arguments(d, 2) isa Ones + @test arguments(d, 1) ≈ f(arguments(a, 1)) + + a = Eye(3) ⊗ Eye(3) + d = @constinferred f(a) + @test d == Ones(3) ⊗ Ones(3) + @test arguments(d, 1) isa Ones + @test arguments(d, 2) isa Ones + end + for f in ( left_orth, left_polar, lq_compact, lq_full, qr_compact, qr_full, right_orth, right_polar ) a = Eye(3) ⊗ randn(3, 3) - x, y = f(a) + x, y = @constinferred f(a) @test x * y ≈ a @test arguments(x, 1) isa Eye @test arguments(y, 1) isa Eye a = randn(3, 3) ⊗ Eye(3) - x, y = f(a) + x, y = @constinferred f(a) @test x * y ≈ a @test arguments(x, 2) isa Eye @test arguments(y, 2) isa Eye @@ -180,21 +197,21 @@ end for f in (svd_compact, svd_full) a = Eye(3) ⊗ randn(3, 3) - u, s, v = f(a) + u, s, v = @constinferred f(a) @test u * s * v ≈ a @test arguments(u, 1) isa Eye @test arguments(s, 1) isa Eye @test arguments(v, 1) isa Eye a = randn(3, 3) ⊗ Eye(3) - u, s, v = f(a) + u, s, v = @constinferred f(a) @test u * s * v ≈ a @test arguments(u, 2) isa Eye @test arguments(s, 2) isa Eye @test arguments(v, 2) isa Eye a = Eye(3) ⊗ Eye(3) - u, s, v = f(a) + u, s, v = @constinferred f(a) @test u * s * v ≈ a @test arguments(u, 1) isa Eye @test arguments(s, 1) isa Eye @@ -203,4 +220,52 @@ end @test arguments(s, 2) isa Eye @test arguments(v, 2) isa Eye end + + a = Eye(3) ⊗ randn(3, 3) + d = @constinferred svd_vals(a) + d′ = svd_vals(Matrix(a)) + @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) + @test arguments(d, 1) isa Ones + @test arguments(d, 2) ≈ svd_vals(arguments(a, 2)) + + a = randn(3, 3) ⊗ Eye(3) + d = @constinferred svd_vals(a) + d′ = svd_vals(Matrix(a)) + @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) + @test arguments(d, 2) isa Ones + @test arguments(d, 1) ≈ svd_vals(arguments(a, 1)) + + a = Eye(3) ⊗ Eye(3) + d = @constinferred svd_vals(a) + @test d == Ones(3) ⊗ Ones(3) + @test arguments(d, 1) isa Ones + @test arguments(d, 2) isa Ones + + # left_null + a = Eye(3) ⊗ randn(3, 3) + n = @constinferred left_null(a) + @test norm(n' * a) ≈ 0 + @test arguments(n, 1) isa Eye + + a = randn(3, 3) ⊗ Eye(3) + n = @constinferred left_null(a) + @test norm(n' * a) ≈ 0 + @test arguments(n, 2) isa Eye + + a = Eye(3) ⊗ Eye(3) + @test_throws MethodError left_null(a) + + # right_null + a = Eye(3) ⊗ randn(3, 3) + n = @constinferred right_null(a) + @test norm(a * n') ≈ 0 + @test arguments(n, 1) isa Eye + + a = randn(3, 3) ⊗ Eye(3) + n = @constinferred right_null(a) + @test norm(a * n') ≈ 0 + @test arguments(n, 2) isa Eye + + a = Eye(3) ⊗ Eye(3) + @test_throws MethodError right_null(a) end