From 8ec462096b8ee86d81cb70a7b4dfd990e5b9bff5 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 17 Jan 2025 15:36:15 -0500 Subject: [PATCH 1/2] Define svd --- src/tensoralgebra.jl | 91 ++++++++++++++++++++++++++----- test/basics/test_tensoralgebra.jl | 29 +++++++--- 2 files changed, 99 insertions(+), 21 deletions(-) diff --git a/src/tensoralgebra.jl b/src/tensoralgebra.jl index 2c9f8d3..61273cc 100644 --- a/src/tensoralgebra.jl +++ b/src/tensoralgebra.jl @@ -1,5 +1,6 @@ -using LinearAlgebra: LinearAlgebra, qr -using TensorAlgebra: TensorAlgebra, blockedperm, contract, contract!, fusedims, splitdims +using LinearAlgebra: LinearAlgebra +using TensorAlgebra: + TensorAlgebra, blockedperm, contract, contract!, fusedims, qr, splitdims, svd using TensorAlgebra.BaseExtensions: BaseExtensions function TensorAlgebra.contract!( @@ -35,6 +36,22 @@ function Base.:*(a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray) return contract(a1, a2) end +# Left associative fold/reduction. +# Circumvent Base definitions: +# ```julia +# *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix) +# *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix) +# ``` +# that optimize matrix multiplication sequence. +function Base.:*( + a1::AbstractNamedDimsArray, + a2::AbstractNamedDimsArray, + a3::AbstractNamedDimsArray, + a_rest::AbstractNamedDimsArray..., +) + return *(*(a1, a2), a3, a_rest...) +end + function LinearAlgebra.mul!( a_dest::AbstractNamedDimsArray, a1::AbstractNamedDimsArray, @@ -99,28 +116,33 @@ function TensorAlgebra.splitdims(na::AbstractNamedDimsArray, splitters::Pair...) return nameddims(a_split, names_split) end -function LinearAlgebra.qr( +function TensorAlgebra.qr( a::AbstractNamedDimsArray, nameddimsindices_codomain, nameddimsindices_domain; positive=nothing, ) @assert isnothing(positive) || !positive - # TODO: This should be `TensorAlgebra.qr` rather than overloading `LinearAlgebra.qr`. - # TODO: Don't require wrapping in `Tuple`. - q, r = qr( + q_unnamed, r_unnamed = qr( unname(a), - Tuple(nameddimsindices(a)), - Tuple(to_nameddimsindices(a, nameddimsindices_codomain)), - Tuple(to_nameddimsindices(a, nameddimsindices_domain)), + nameddimsindices(a), + to_nameddimsindices(a, nameddimsindices_codomain), + to_nameddimsindices(a, nameddimsindices_domain), + ) + name_q = randname(dimnames(a, 1)) + name_r = name_q + namedindices_q = named(last(axes(q_unnamed)), name_q) + namedindices_r = named(first(axes(r_unnamed)), name_r) + nameddimsindices_q = ( + to_nameddimsindices(a, nameddimsindices_codomain)..., namedindices_q ) - name_qr = randname(nameddimsindices(a)[1]) - nameddimsindices_q = (to_nameddimsindices(a, nameddimsindices_codomain)..., name_qr) - nameddimsindices_r = (name_qr, to_nameddimsindices(a, nameddimsindices_domain)...) - return nameddims(q, nameddimsindices_q), nameddims(r, nameddimsindices_r) + nameddimsindices_r = (namedindices_r, to_nameddimsindices(a, nameddimsindices_domain)...) + q = nameddims(q_unnamed, nameddimsindices_q) + r = nameddims(r_unnamed, nameddimsindices_r) + return q, r end -function LinearAlgebra.qr(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...) +function TensorAlgebra.qr(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...) return qr( a, nameddimsindices_codomain, @@ -128,3 +150,44 @@ function LinearAlgebra.qr(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs..., ) end + +function LinearAlgebra.qr(a::AbstractNamedDimsArray, args...; kwargs...) + return TensorAlgebra.qr(a, args...; kwargs...) +end + +function TensorAlgebra.svd( + a::AbstractNamedDimsArray, nameddimsindices_codomain, nameddimsindices_domain +) + u_unnamed, s_unnamed, v_unnamed = svd( + unname(a), + nameddimsindices(a), + to_nameddimsindices(a, nameddimsindices_codomain), + to_nameddimsindices(a, nameddimsindices_domain), + ) + name_u = randname(dimnames(a, 1)) + name_v = randname(dimnames(a, 1)) + namedindices_u = named(last(axes(u_unnamed)), name_u) + namedindices_v = named(first(axes(v_unnamed)), name_v) + nameddimsindices_u = ( + to_nameddimsindices(a, nameddimsindices_codomain)..., namedindices_u + ) + nameddimsindices_s = (namedindices_u, namedindices_v) + nameddimsindices_v = (namedindices_v, to_nameddimsindices(a, nameddimsindices_domain)...) + u = nameddims(u_unnamed, nameddimsindices_u) + s = nameddims(s_unnamed, nameddimsindices_s) + v = nameddims(v_unnamed, nameddimsindices_v) + return u, s, v +end + +function TensorAlgebra.svd(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...) + return svd( + a, + nameddimsindices_codomain, + setdiff(nameddimsindices(a), to_nameddimsindices(a, nameddimsindices_codomain)); + kwargs..., + ) +end + +function LinearAlgebra.svd(a::AbstractNamedDimsArray, args...; kwargs...) + return TensorAlgebra.svd(a, args...; kwargs...) +end diff --git a/test/basics/test_tensoralgebra.jl b/test/basics/test_tensoralgebra.jl index 6979df1..f9576c5 100644 --- a/test/basics/test_tensoralgebra.jl +++ b/test/basics/test_tensoralgebra.jl @@ -1,4 +1,4 @@ -using LinearAlgebra: qr +using LinearAlgebra: qr, svd using NamedDimsArrays: namedoneto, dename using TensorAlgebra: TensorAlgebra, contract, fusedims, splitdims using Test: @test, @testset, @test_broken @@ -47,15 +47,30 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) dims = (2, 2, 2, 2) i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l")) - na = randn(elt, i, j) + a = randn(elt, i, j) # TODO: Should this be allowed? # TODO: Add support for specifying new name. - q, r = qr(na, (i,)) - @test q * r ≈ na + q, r = qr(a, (i,)) + @test q * r ≈ a - na = randn(elt, i, j, k, l) + a = randn(elt, i, j, k, l) + # TODO: Add support for specifying new name. + q, r = qr(a, (i, k), (j, l)) + @test q * r ≈ a + end + @testset "svd" begin + dims = (2, 2, 2, 2) + i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l")) + + a = randn(elt, i, j) + # TODO: Should this be allowed? + # TODO: Add support for specifying new name. + u, s, v = svd(a, (i,)) + @test u * s * v ≈ a + + a = randn(elt, i, j, k, l) # TODO: Add support for specifying new name. - q, r = qr(na, (i, k), (j, l)) - @test contract(q, r) ≈ na + u, s, v = svd(a, (i, k), (j, l)) + @test u * s * v ≈ a end end From c261918ec6e7eae6dc31f96e06101f53c46ba89a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 17 Jan 2025 15:40:00 -0500 Subject: [PATCH 2/2] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index da73bca..92c102a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NamedDimsArrays" uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" authors = ["ITensor developers and contributors"] -version = "0.3.7" +version = "0.3.8" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"