diff --git a/Project.toml b/Project.toml index c0b41e7..7a85171 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.4.0" +version = "0.4.1" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index ee32607..125f9b7 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -29,5 +29,6 @@ include("contract/blockedperms.jl") include("contract/allocate_output.jl") include("contract/contract_matricize/contract.jl") include("factorizations.jl") +include("matrixfunctions.jl") end diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl new file mode 100644 index 0000000..871769e --- /dev/null +++ b/src/matrixfunctions.jl @@ -0,0 +1,46 @@ +# TensorAlgebra version of matrix functions. +const MATRIX_FUNCTIONS = [ + :exp, + :cis, + :log, + :sqrt, + :cbrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, +] + +for f in MATRIX_FUNCTIONS + @eval begin + function $f(a::AbstractArray, labels_a, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_a, labels_codomain, labels_domain))...) + return $f(a, biperm; kwargs...) + end + function $f(a::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) + a_mat = matricize(a, biperm) + fa_mat = Base.$f(a_mat; kwargs...) + return unmatricize(fa_mat, axes(a)[biperm]) + end + end +end diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index efb81fc..f23f38d 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,8 +1,4 @@ using LinearAlgebra: LinearAlgebra, norm, diag -using Test: @test, @testset - -using TestExtras: @constinferred - using MatrixAlgebraKit: truncrank using TensorAlgebra: contract, @@ -21,6 +17,8 @@ using TensorAlgebra: right_polar, svd, svdvals +using Test: @test, @testset +using TestExtras: @constinferred elts = (Float64, ComplexF64) diff --git a/test/test_matrixfunctions.jl b/test/test_matrixfunctions.jl new file mode 100644 index 0000000..3e35e3c --- /dev/null +++ b/test/test_matrixfunctions.jl @@ -0,0 +1,21 @@ +using StableRNGs: StableRNG +using TensorAlgebra: TensorAlgebra, biperm +using Test: @test, @testset + +@testset "Matrix functions (eltype=$elt)" for elt in (Float32, ComplexF64) + for f in TensorAlgebra.MATRIX_FUNCTIONS + f == :cbrt && elt <: Complex && continue + f == :cbrt && VERSION < v"1.11-" && continue + @eval begin + rng = StableRNG(123) + a = randn(rng, $elt, (2, 2, 2, 2)) + for fa in ( + TensorAlgebra.$f(a, (:a, :b, :c, :d), (:c, :b), (:d, :a)), + TensorAlgebra.$f(a, biperm((3, 2, 4, 1), Val(2))), + ) + fa′ = reshape($f(reshape(permutedims(a, (3, 2, 4, 1)), (4, 4))), (2, 2, 2, 2)) + @test fa ≈ fa′ + end + end + end +end