Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NamedDimsArrays"
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.6.0"
version = "0.6.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
189 changes: 148 additions & 41 deletions src/tensoralgebra.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
using LinearAlgebra: LinearAlgebra
using TensorAlgebra:
TensorAlgebra, blockedperm, contract, contract!, fusedims, permmortar, qr, splitdims, svd
TensorAlgebra,
blockedperm,
contract,
contract!,
eigen,
eigvals,
fusedims,
left_null,
lq,
permmortar,
qr,
right_null,
splitdims,
svd,
svdvals
using TensorAlgebra.BaseExtensions: BaseExtensions

function TensorAlgebra.contract!(
Expand Down Expand Up @@ -94,7 +108,7 @@
)
end
perm = blockedperm(na, nameddimsindices_fuse...)
a_fused = fusedims(unname(na), perm)
a_fused = fusedims(dename(na), perm)
return nameddimsarray(a_fused, nameddimsindices_fused)
end

Expand All @@ -107,7 +121,7 @@
split_lengths = unname.(split_namedlengths)
return fused_dim => split_lengths
end
a_split = splitdims(unname(na), splitters_unnamed...)
a_split = splitdims(dename(na), splitters_unnamed...)
names_split = Any[tuple.(nameddimsindices(na))...]
for splitter in splitters
fused_name, split_namedlengths = splitter
Expand All @@ -120,77 +134,170 @@
end

function TensorAlgebra.qr(
a::AbstractNamedDimsArray,
nameddimsindices_codomain,
nameddimsindices_domain;
positive=nothing,
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
)
@assert isnothing(positive) || !positive
q_unnamed, r_unnamed = qr(
unname(a),
nameddimsindices(a),
to_nameddimsindices(a, nameddimsindices_codomain),
to_nameddimsindices(a, nameddimsindices_domain),
)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = to_nameddimsindices(a, dimnames_domain)
q_unnamed, r_unnamed = qr(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
name_q = randname(dimnames(a, 1))
name_r = name_q
namedindices_q = named(last(axes(q_unnamed)), name_q)
namedindices_r = named(first(axes(r_unnamed)), name_r)
nameddimsindices_q = (
to_nameddimsindices(a, nameddimsindices_codomain)..., namedindices_q
)
nameddimsindices_r = (namedindices_r, to_nameddimsindices(a, nameddimsindices_domain)...)
nameddimsindices_q = (codomain..., namedindices_q)
nameddimsindices_r = (namedindices_r, domain...)
q = nameddimsarray(q_unnamed, nameddimsindices_q)
r = nameddimsarray(r_unnamed, nameddimsindices_r)
return q, r
end

function TensorAlgebra.qr(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...)
return qr(
a,
nameddimsindices_codomain,
setdiff(nameddimsindices(a), to_nameddimsindices(a, nameddimsindices_codomain));
kwargs...,
)
function TensorAlgebra.qr(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = setdiff(nameddimsindices(a), codomain)
return qr(a, codomain, domain; kwargs...)
end

function LinearAlgebra.qr(a::AbstractNamedDimsArray, args...; kwargs...)
return TensorAlgebra.qr(a, args...; kwargs...)
end

function TensorAlgebra.lq(
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = to_nameddimsindices(a, dimnames_domain)
l_unnamed, q_unnamed = lq(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
name_l = randname(dimnames(a, 1))
name_q = name_l
namedindices_l = named(last(axes(l_unnamed)), name_l)
namedindices_q = named(first(axes(q_unnamed)), name_q)
nameddimsindices_l = (codomain..., namedindices_l)
nameddimsindices_q = (namedindices_q, domain...)
l = nameddimsarray(l_unnamed, nameddimsindices_l)
q = nameddimsarray(q_unnamed, nameddimsindices_q)
return l, q
end
function TensorAlgebra.lq(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = setdiff(nameddimsindices(a), codomain)
return lq(a, codomain, domain; kwargs...)
end
function LinearAlgebra.lq(a::AbstractNamedDimsArray, args...; kwargs...)
return TensorAlgebra.lq(a, args...; kwargs...)
end

function TensorAlgebra.svd(
a::AbstractNamedDimsArray, nameddimsindices_codomain, nameddimsindices_domain
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = to_nameddimsindices(a, dimnames_domain)
u_unnamed, s_unnamed, v_unnamed = svd(
unname(a),
nameddimsindices(a),
to_nameddimsindices(a, nameddimsindices_codomain),
to_nameddimsindices(a, nameddimsindices_domain),
dename(a), nameddimsindices(a), codomain, domain; kwargs...
)
name_u = randname(dimnames(a, 1))
name_v = randname(dimnames(a, 1))
namedindices_u = named(last(axes(u_unnamed)), name_u)
namedindices_v = named(first(axes(v_unnamed)), name_v)
nameddimsindices_u = (
to_nameddimsindices(a, nameddimsindices_codomain)..., namedindices_u
)
nameddimsindices_u = (codomain..., namedindices_u)
nameddimsindices_s = (namedindices_u, namedindices_v)
nameddimsindices_v = (namedindices_v, to_nameddimsindices(a, nameddimsindices_domain)...)
nameddimsindices_v = (namedindices_v, domain...)
u = nameddimsarray(u_unnamed, nameddimsindices_u)
s = nameddimsarray(s_unnamed, nameddimsindices_s)
v = nameddimsarray(v_unnamed, nameddimsindices_v)
return u, s, v
end

function TensorAlgebra.svd(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...)
function TensorAlgebra.svd(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
return svd(
a,
nameddimsindices_codomain,
setdiff(nameddimsindices(a), to_nameddimsindices(a, nameddimsindices_codomain));
dimnames_codomain,
setdiff(nameddimsindices(a), to_nameddimsindices(a, dimnames_codomain));
kwargs...,
)
end

function LinearAlgebra.svd(a::AbstractNamedDimsArray, args...; kwargs...)
return TensorAlgebra.svd(a, args...; kwargs...)
end

function TensorAlgebra.svdvals(

Check warning on line 218 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L218

Added line #L218 was not covered by tests
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
)
return svdvals(

Check warning on line 221 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L221

Added line #L221 was not covered by tests
dename(a),
nameddimsindices(a),
to_nameddimsindices(a, dimnames_codomain),
to_nameddimsindices(a, dimnames_domain);
kwargs...,
)
end
function TensorAlgebra.svdvals(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = setdiff(nameddimsindices(a), codomain)
return svdvals(a, codomain, domain; kwargs...)

Check warning on line 232 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L229-L232

Added lines #L229 - L232 were not covered by tests
end
function LinearAlgebra.svdvals(a::AbstractNamedDimsArray, args...; kwargs...)
return TensorAlgebra.svdvals(a, args...; kwargs...)

Check warning on line 235 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L234-L235

Added lines #L234 - L235 were not covered by tests
end

function TensorAlgebra.eigen(

Check warning on line 238 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L238

Added line #L238 was not covered by tests
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = to_nameddimsindices(a, dimnames_domain)
d_unnamed, v_unnamed = eigen(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
name_d = randname(dimnames(a, 1))
name_d′ = randname(name_d)
name_v = name_d
namedindices_d = named(last(axes(d_unnamed)), name_d)
namedindices_d′ = named(first(axes(d_unnamed)), name_d′)
namedindices_v = named(last(axes(v_unnamed)), name_v)
nameddimsindices_d = (namedindices_d′, namedindices_d)
nameddimsindices_v = (domain..., namedindices_v)
d = nameddimsarray(d_unnamed, nameddimsindices_d)
v = nameddimsarray(v_unnamed, nameddimsindices_v)
return d, v

Check warning on line 254 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L241-L254

Added lines #L241 - L254 were not covered by tests
end
function LinearAlgebra.eigen(a::AbstractNamedDimsArray, args...; kwargs...)
return TensorAlgebra.eigen(a, args...; kwargs...)

Check warning on line 257 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L256-L257

Added lines #L256 - L257 were not covered by tests
end

function TensorAlgebra.eigvals(

Check warning on line 260 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L260

Added line #L260 was not covered by tests
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = to_nameddimsindices(a, dimnames_domain)
return eigvals(dename(a), nameddimsindices(a), codomain, domain; kwargs...)

Check warning on line 265 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L263-L265

Added lines #L263 - L265 were not covered by tests
end
function LinearAlgebra.eigvals(a::AbstractNamedDimsArray, args...; kwargs...)
return TensorAlgebra.eigvals(a, args...; kwargs...)

Check warning on line 268 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L267-L268

Added lines #L267 - L268 were not covered by tests
end

function TensorAlgebra.left_null(
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = to_nameddimsindices(a, dimnames_domain)
n_unnamed = left_null(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
name_n = randname(dimnames(a, 1))
namedindices_n = named(last(axes(n_unnamed)), name_n)
nameddimsindices_n = (codomain..., namedindices_n)
return nameddimsarray(n_unnamed, nameddimsindices_n)
end
function TensorAlgebra.left_null(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = setdiff(nameddimsindices(a), codomain)
return left_null(a, codomain, domain; kwargs...)
end

function TensorAlgebra.right_null(
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = to_nameddimsindices(a, dimnames_domain)
n_unnamed = right_null(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
name_n = randname(dimnames(a, 1))
namedindices_n = named(first(axes(n_unnamed)), name_n)
nameddimsindices_n = (namedindices_n, domain...)
return nameddimsarray(n_unnamed, nameddimsindices_n)
end
function TensorAlgebra.right_null(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
codomain = to_nameddimsindices(a, dimnames_codomain)
domain = setdiff(nameddimsindices(a), codomain)
return right_null(a, codomain, domain; kwargs...)
end
39 changes: 32 additions & 7 deletions test/test_tensoralgebra.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using LinearAlgebra: qr, svd
using NamedDimsArrays: namedoneto, dename
using LinearAlgebra: lq, norm, qr, svd
using NamedDimsArrays: dename, left_null, nameddimsindices, namedoneto, right_null
using TensorAlgebra: TensorAlgebra, contract, fusedims, splitdims
using Test: @test, @testset, @test_broken
elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
Expand Down Expand Up @@ -43,20 +43,24 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test dename(na_split, ("j", "i", "b")) ≈
reshape(dename(na, ("a", "b")), (dename(j), dename(i), dename(b)))
end
@testset "qr" begin
@testset "qr/lq" begin
dims = (2, 2, 2, 2)
i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l"))

a = randn(elt, i, j)
# TODO: Should this be allowed?
# TODO: Add support for specifying new name.
q, r = qr(a, (i,))
@test q * r ≈ a
for f in (qr, lq)
x, y = f(a, (i,))
@test x * y ≈ a
end

a = randn(elt, i, j, k, l)
# TODO: Add support for specifying new name.
q, r = qr(a, (i, k), (j, l))
@test q * r ≈ a
for f in (qr, lq)
x, y = f(a, (i, k), (j, l))
@test x * y ≈ a
end
end
@testset "svd" begin
dims = (2, 2, 2, 2)
Expand All @@ -72,5 +76,26 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
# TODO: Add support for specifying new name.
u, s, v = svd(a, (i, k), (j, l))
@test u * s * v ≈ a

# Test truncation.
a = randn(elt, i, j, k, l)
u, s, v = svd(a, (i, k), (j, l); trunc=(; maxrank=2))
@test u * s * v ≉ a
@test Int.(Tuple(size(s))) == (2, 2)
end
@testset "left_null/eight_null" begin
dims = (2, 2, 2, 2)
i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l"))

a = randn(elt, i, j, k, l)
# TODO: Add support for specifying new name.
for n in (left_null(a, (i, k), (j, l)), left_null(a, (i, k)))
@test (i, k) ⊆ nameddimsindices(n)
@test norm(n * a) ≈ 0
end
for n in (right_null(a, (i, k), (j, l)), right_null(a, (i, k)))
@test (j, l) ⊆ nameddimsindices(n)
@test norm(n * a) ≈ 0
end
end
end
Loading