diff --git a/Project.toml b/Project.toml index 4f426d1..41128d7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NamedDimsArrays" uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" authors = ["ITensor developers and contributors"] -version = "0.7.16" +version = "0.7.17" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -37,7 +37,7 @@ LinearAlgebra = "1.10" MapBroadcast = "0.1.6" Random = "1.10" SimpleTraits = "0.9.4" -TensorAlgebra = "0.3, 0.4" +TensorAlgebra = "0.4.1" TupleTools = "1.6.0" TypeParameterAccessors = "0.4" julia = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index c71bbb5..c2ccacb 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,6 +7,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] NamedDimsArrays = "0.7" -TensorAlgebra = "0.3, 0.4" +TensorAlgebra = "0.4" Literate = "2" Documenter = "1" diff --git a/examples/Project.toml b/examples/Project.toml index 0773a01..10c4a90 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] NamedDimsArrays = "0.7" -TensorAlgebra = "0.3, 0.4" +TensorAlgebra = "0.4" diff --git a/src/abstractnameddimsarray.jl b/src/abstractnameddimsarray.jl index 5faee71..6a00417 100644 --- a/src/abstractnameddimsarray.jl +++ b/src/abstractnameddimsarray.jl @@ -309,6 +309,11 @@ function replacenameddimsindices(a::AbstractNamedDimsArray, replacements::Pair.. new_nameddimsindices = named.(dename.(old_nameddimsindices), last.(replacements)) return replacenameddimsindices(a, (old_nameddimsindices .=> new_nameddimsindices)...) end +function replacenameddimsindices(a::AbstractNamedDimsArray, replacements::Dict) + return replacenameddimsindices(a) do name + return get(replacements, name, name) + end +end function mapnameddimsindices(f, a::AbstractNamedDimsArray) return setnameddimsindices(a, map(f, nameddimsindices(a))) end diff --git a/src/nameddimsoperator.jl b/src/nameddimsoperator.jl index d7cebef..c27f7cc 100644 --- a/src/nameddimsoperator.jl +++ b/src/nameddimsoperator.jl @@ -127,6 +127,15 @@ Base.:*(a::AbstractNamedDimsOperator, b::AbstractNamedDimsOperator) = state(a) * Base.:*(a::AbstractNamedDimsOperator, b::AbstractNamedDimsArray) = state(a) * state(b) Base.:*(a::AbstractNamedDimsArray, b::AbstractNamedDimsOperator) = state(a) * state(b) +for f in MATRIX_FUNCTIONS + @eval begin + function Base.$f(a::AbstractNamedDimsOperator) + c = codomain(a) + d = domain(a) + return operator($f(state(a), c, d), c .=> d) + end + end +end struct NamedDimsOperator{T,N,P<:AbstractNamedDimsArray{T,N},D,C} <: AbstractNamedDimsOperator{T,N} parent::P diff --git a/src/tensoralgebra.jl b/src/tensoralgebra.jl index 95da841..05328e3 100644 --- a/src/tensoralgebra.jl +++ b/src/tensoralgebra.jl @@ -4,6 +4,7 @@ using TensorAlgebra: blockedperm, contract, contract!, + contractadd!, eigen, eigvals, factorize, @@ -25,14 +26,14 @@ using TensorAlgebra: using TensorAlgebra.BaseExtensions: BaseExtensions using TupleTools: TupleTools -function TensorAlgebra.contract!( +function TensorAlgebra.contractadd!( a_dest::AbstractNamedDimsArray, a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray, - α::Number=true, - β::Number=false, + α::Number, + β::Number, ) - contract!( + contractadd!( dename(a_dest), nameddimsindices(a_dest), dename(a1), @@ -45,6 +46,12 @@ function TensorAlgebra.contract!( return a_dest end +function TensorAlgebra.contract!( + a_dest::AbstractNamedDimsArray, a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray +) + return contractadd!(a_dest, a1, a2, true, false) +end + function TensorAlgebra.contract(a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray) a_dest, nameddimsindices_dest = contract( dename(a1), nameddimsindices(a1), dename(a2), nameddimsindices(a2) @@ -79,10 +86,17 @@ function LinearAlgebra.mul!( a_dest::AbstractNamedDimsArray, a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray, - α::Number=true, - β::Number=false, + α::Number, + β::Number, ) - contract!(a_dest, a1, a2, α, β) + contractadd!(a_dest, a1, a2, α, β) + return a_dest +end + +function LinearAlgebra.mul!( + a_dest::AbstractNamedDimsArray, a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray +) + contract!(a_dest, a1, a2) return a_dest end @@ -301,3 +315,50 @@ function TensorAlgebra.right_null(a::AbstractNamedDimsArray, dimnames_codomain; domain = setdiff(nameddimsindices(a), codomain) return right_null(a, codomain, domain; kwargs...) end + +const MATRIX_FUNCTIONS = [ + :exp, + :cis, + :log, + :sqrt, + :cbrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, +] + +for f in MATRIX_FUNCTIONS + @eval begin + function Base.$f( + a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs... + ) + codomain = to_nameddimsindices(a, dimnames_codomain) + domain = to_nameddimsindices(a, dimnames_domain) + fa_unnamed = TensorAlgebra.$f( + dename(a), nameddimsindices(a), codomain, domain; kwargs... + ) + return nameddimsarray(fa_unnamed, (codomain..., domain...)) + end + end +end diff --git a/test/Project.toml b/test/Project.toml index dde0c35..03842dc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -16,11 +17,11 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Adapt = "4" Aqua = "0.8.9" BlockArrays = "1" -BlockSparseArrays = "0.8, 0.9, 0.10" +BlockSparseArrays = "0.10" Combinatorics = "1" GradedArrays = "0.4" NamedDimsArrays = "0.7" SafeTestsets = "0.1" Suppressor = "0.2" -TensorAlgebra = "0.3, 0.4" +TensorAlgebra = "0.4" Test = "1.10" diff --git a/test/test_tensoralgebra.jl b/test/test_tensoralgebra.jl index 21a5d7d..3d3c961 100644 --- a/test/test_tensoralgebra.jl +++ b/test/test_tensoralgebra.jl @@ -1,5 +1,6 @@ using LinearAlgebra: factorize, lq, norm, qr, svd -using NamedDimsArrays: dename, nameddimsindices, namedoneto +using NamedDimsArrays: NamedDimsArrays, dename, nameddimsindices, namedoneto +using StableRNGs: StableRNG using TensorAlgebra: TensorAlgebra, contract, @@ -14,6 +15,7 @@ using TensorAlgebra: right_polar, unmatricize using Test: @test, @testset, @test_broken + elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "TensorAlgebra (eltype=$(elt))" for elt in elts @testset "contract" begin @@ -45,6 +47,21 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test dename(na_split, ("k", "i", "j", "l")) ≈ reshape(dename(na, ("a", "b")), (dename(k), dename(i), dename(j), dename(l))) end + @testset "Matrix functions" begin + for f in NamedDimsArrays.MATRIX_FUNCTIONS + f == :cbrt && elt <: Complex && continue + f == :cbrt && VERSION < v"1.11-" && continue + @eval begin + i, j, k, l = namedoneto.((2, 2, 2, 2), ("i", "j", "k", "l")) + rng = StableRNG(123) + a = randn(rng, $elt, (i, j, k, l)) + fa = $f(a, (j, l), (k, i)) + m = dename(matricize(a, (j, l) => "a", (k, i) => "b"), ("a", "b")) + fm = dename(matricize(fa, (j, l) => "a", (k, i) => "b"), ("a", "b")) + @test fm ≈ $f(m) + end + end + end @testset "qr/lq" begin dims = (2, 2, 2, 2) i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l"))