Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.0"
version = "0.3.1"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
25 changes: 5 additions & 20 deletions src/matricize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,31 +38,16 @@

# TODO remove _permutedims once support for Julia 1.10 is dropped
# define permutedims with a BlockedPermuation. Default is to flatten it.
function Base.permutedims(a::AbstractArray, biperm::AbstractBlockPermutation)
function blockpermutedims(a::AbstractArray, biperm::AbstractBlockPermutation)

Check warning on line 41 in src/matricize.jl

View check run for this annotation

Codecov / codecov/patch

src/matricize.jl#L41

Added line #L41 was not covered by tests
return _permutedims(a, Tuple(biperm))
end

# solve ambiguities
function Base.permutedims(a::StridedArray, biperm::AbstractBlockPermutation)
return _permutedims(a, Tuple(biperm))
end
function Base.permutedims(a::Diagonal, biperm::AbstractBlockPermutation)
return _permutedims(a, Tuple(biperm))
end

function Base.permutedims!(
function blockpermutedims!(

Check warning on line 45 in src/matricize.jl

View check run for this annotation

Codecov / codecov/patch

src/matricize.jl#L45

Added line #L45 was not covered by tests
a::AbstractArray, b::AbstractArray, biperm::AbstractBlockPermutation
)
return _permutedims!(a, b, Tuple(biperm))
end

# solve ambiguities
function Base.permutedims!(
a::Array{T,N}, b::StridedArray{T,N}, biperm::AbstractBlockPermutation
) where {T,N}
return _permutedims!(a, b, Tuple(biperm))
end

# ===================================== matricize ========================================
# TBD settle copy/not copy convention
# matrix factorizations assume copy
Expand All @@ -75,7 +60,7 @@
function matricize(
style::FusionStyle, a::AbstractArray, biperm::AbstractBlockPermutation{2}
)
a_perm = permutedims(a, biperm)
a_perm = blockpermutedims(a, biperm)

Check warning on line 63 in src/matricize.jl

View check run for this annotation

Codecov / codecov/patch

src/matricize.jl#L63

Added line #L63 was not covered by tests
return matricize(style, a_perm, trivialperm(biperm))
end

Expand Down Expand Up @@ -112,7 +97,7 @@
)
blocked_axes = axes[biperm]
a_perm = unmatricize(m, blocked_axes)
return permutedims(a_perm, invperm(biperm))
return blockpermutedims(a_perm, invperm(biperm))

Check warning on line 100 in src/matricize.jl

View check run for this annotation

Codecov / codecov/patch

src/matricize.jl#L100

Added line #L100 was not covered by tests
end

function unmatricize(::ReshapeFusion, m::AbstractMatrix, axes::AbstractUnitRange...)
Expand Down Expand Up @@ -147,5 +132,5 @@
)
blocked_axes = axes(a)[biperm]
a_perm = unmatricize(m, blocked_axes)
return permutedims!(a, a_perm, invperm(biperm))
return blockpermutedims!(a, a_perm, invperm(biperm))

Check warning on line 135 in src/matricize.jl

View check run for this annotation

Codecov / codecov/patch

src/matricize.jl#L135

Added line #L135 was not covered by tests
end
20 changes: 19 additions & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,30 @@ using StableRNGs: StableRNG
using TensorOperations: TensorOperations

using TensorAlgebra:
blockedpermvcat, contract, contract!, matricize, tuplemortar, unmatricize, unmatricize!
blockedpermvcat,
blockpermutedims,
blockpermutedims!,
contract,
contract!,
matricize,
tuplemortar,
unmatricize,
unmatricize!

default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt))))
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

@testset "TensorAlgebra" begin
@testset "blockpermutedims (eltype=$elt)" for elt in elts
a = randn(elt, 2, 3, 4, 5)
a_perm = blockpermutedims(a, blockedpermvcat((3, 1), (2, 4)))
@test a_perm == permutedims(a, (3, 1, 2, 4))

a = randn(elt, 2, 3, 4, 5)
a_perm = Array{elt}(undef, (4, 2, 3, 5))
blockpermutedims!(a_perm, a, blockedpermvcat((3, 1), (2, 4)))
@test a_perm == permutedims(a, (3, 1, 2, 4))
end
@testset "matricize (eltype=$elt)" for elt in elts
a = randn(elt, 2, 3, 4, 5)

Expand Down
Loading