diff --git a/Project.toml b/Project.toml index 288b1e9..b285dd2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.3.0" +version = "0.3.1" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/matricize.jl b/src/matricize.jl index 49bc949..fa71531 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -38,31 +38,16 @@ end # TODO remove _permutedims once support for Julia 1.10 is dropped # define permutedims with a BlockedPermuation. Default is to flatten it. -function Base.permutedims(a::AbstractArray, biperm::AbstractBlockPermutation) +function permuteblockeddims(a::AbstractArray, biperm::AbstractBlockPermutation) return _permutedims(a, Tuple(biperm)) end -# solve ambiguities -function Base.permutedims(a::StridedArray, biperm::AbstractBlockPermutation) - return _permutedims(a, Tuple(biperm)) -end -function Base.permutedims(a::Diagonal, biperm::AbstractBlockPermutation) - return _permutedims(a, Tuple(biperm)) -end - -function Base.permutedims!( +function permuteblockeddims!( a::AbstractArray, b::AbstractArray, biperm::AbstractBlockPermutation ) return _permutedims!(a, b, Tuple(biperm)) end -# solve ambiguities -function Base.permutedims!( - a::Array{T,N}, b::StridedArray{T,N}, biperm::AbstractBlockPermutation -) where {T,N} - return _permutedims!(a, b, Tuple(biperm)) -end - # ===================================== matricize ======================================== # TBD settle copy/not copy convention # matrix factorizations assume copy @@ -75,7 +60,7 @@ end function matricize( style::FusionStyle, a::AbstractArray, biperm::AbstractBlockPermutation{2} ) - a_perm = permutedims(a, biperm) + a_perm = permuteblockeddims(a, biperm) return matricize(style, a_perm, trivialperm(biperm)) end @@ -112,7 +97,7 @@ function unmatricize( ) blocked_axes = axes[biperm] a_perm = unmatricize(m, blocked_axes) - return permutedims(a_perm, invperm(biperm)) + return permuteblockeddims(a_perm, invperm(biperm)) end function unmatricize(::ReshapeFusion, m::AbstractMatrix, axes::AbstractUnitRange...) @@ -147,5 +132,5 @@ function unmatricize!( ) blocked_axes = axes(a)[biperm] a_perm = unmatricize(m, blocked_axes) - return permutedims!(a, a_perm, invperm(biperm)) + return permuteblockeddims!(a, a_perm, invperm(biperm)) end diff --git a/test/test_basics.jl b/test/test_basics.jl index fc11630..53d48a9 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -5,12 +5,30 @@ using StableRNGs: StableRNG using TensorOperations: TensorOperations using TensorAlgebra: - blockedpermvcat, contract, contract!, matricize, tuplemortar, unmatricize, unmatricize! + blockedpermvcat, + permuteblockeddims, + permuteblockeddims!, + contract, + contract!, + matricize, + tuplemortar, + unmatricize, + unmatricize! default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt)))) const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "TensorAlgebra" begin + @testset "permuteblockeddims (eltype=$elt)" for elt in elts + a = randn(elt, 2, 3, 4, 5) + a_perm = permuteblockeddims(a, blockedpermvcat((3, 1), (2, 4))) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + + a = randn(elt, 2, 3, 4, 5) + a_perm = Array{elt}(undef, (4, 2, 3, 5)) + permuteblockeddims!(a_perm, a, blockedpermvcat((3, 1), (2, 4))) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + end @testset "matricize (eltype=$elt)" for elt in elts a = randn(elt, 2, 3, 4, 5)