From 69b2d4858d00308e648554ca7232b9616d9d290f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 29 Jul 2025 16:34:41 -0400 Subject: [PATCH 1/3] Define Delta type alias, TensorAlgebra.matricize --- Project.toml | 9 ++++++++- src/diagonalarray/delta.jl | 31 ++++++++++++++++++++++++++++--- test/test_basics.jl | 6 +++++- 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index c937ed9..23e53ac 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.10" +version = "0.3.11" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -10,10 +10,17 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" +[weakdeps] +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" + +[extensions] +DiagonalArraysTensorAlgebraExt = "TensorAlgebra" + [compat] ArrayLayouts = "1.10.4" DerivableInterfaces = "0.5" FillArrays = "1.13.0" LinearAlgebra = "1.10.0" SparseArraysBase = "0.7.2" +TensorAlgebra = "0.3.10" julia = "1.10" diff --git a/src/diagonalarray/delta.jl b/src/diagonalarray/delta.jl index d4014dd..b2dc860 100644 --- a/src/diagonalarray/delta.jl +++ b/src/diagonalarray/delta.jl @@ -1,9 +1,34 @@ -using FillArrays: Ones +using FillArrays: Ones, OnesVector + +const Delta{T,N,V<:OnesVector{T},Axes} = DiagonalArray{T,N,V,Axes} +function Delta{T}( + ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} +) where {T} + uniquelens = unique(length, ax) + if !isone(length(uniquelens)) + throw(ArgumentError("All axes must have the same length for Delta.")) + end + return DiagonalArray{T}(Ones{T}(only(uniquelens)), ax) +end +function Delta{T}( + ax1::AbstractUnitRange{<:Integer}, ax_rest::AbstractUnitRange{<:Integer}... +) where {T} + return Delta{T}((ax1, ax_rest...)) +end +function Delta{T}(sz::Tuple{Integer,Vararg{Integer}}) where {T} + return Delta{T}(map(Base.OneTo, sz)) +end +function Delta{T}(sz1::Integer, sz_rest::Integer...) where {T} + return Delta{T}((sz1, sz_rest...)) +end +function Delta{T}(ax::Tuple{}) where {T} + return DiagonalArray{T}(Ones{T}(0), ax) +end function delta( elt::Type, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} ) - return DiagonalArray(Ones{elt}(minimum(length, ax)), ax) + return Delta{elt}(ax) end function δ( elt::Type, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} @@ -35,7 +60,7 @@ function δ(ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer} end function delta(elt::Type, size::Tuple{Vararg{Int}}) - return DiagonalArray(Ones{elt}(minimum(size)), size) + return Delta{elt}(size) end function δ(elt::Type, size::Tuple{Vararg{Int}}) return delta(elt, size) diff --git a/test/test_basics.jl b/test/test_basics.jl index 70eceaf..026e408 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,6 +1,7 @@ using Test: @test, @testset, @test_broken, @inferred using DiagonalArrays: DiagonalArrays, + Delta, DiagonalArray, DiagonalMatrix, δ, @@ -161,6 +162,8 @@ using LinearAlgebra: Diagonal (δ(Base.OneTo.((2, 3))), Float64), (delta(Bool, 2, 3), Bool), (delta(Bool, Base.OneTo(2), Base.OneTo(3)), Bool), + (Delta{Bool}((2, 3)), Bool), + (Delta{Bool}(Base.OneTo.((2, 3))), Bool), (δ(Bool, 2, 3), Bool), (δ(Bool, Base.OneTo(2), Base.OneTo(3)), Bool), (delta(Bool, (2, 3)), Bool), @@ -170,7 +173,8 @@ using LinearAlgebra: Diagonal ) @test eltype(a) === elt′ @test diaglength(a) == 2 - @test a isa DiagonalArray{elt′} + @test a isa DiagonalArray{elt′,2} + @test a isa Delta{elt′,2} @test size(a) == (2, 3) @test diaglength(a) == 2 @test storedlength(a) == 2 From 4d71e63384a5381423c915f1f1c9ac5c682533b9 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 29 Jul 2025 16:43:41 -0400 Subject: [PATCH 2/3] Define Delta type alias, TensorAlgebra.matricize --- .../DiagonalArraysTensorAlgebraExt.jl | 36 ++++++++++++++++ test/Project.toml | 2 + test/test_basics.jl | 42 +++++++++---------- test/test_tensoralgebraext.jl | 11 +++++ 4 files changed, 70 insertions(+), 21 deletions(-) create mode 100644 ext/DiagonalArraysTensorAlgebraExt/DiagonalArraysTensorAlgebraExt.jl create mode 100644 test/test_tensoralgebraext.jl diff --git a/ext/DiagonalArraysTensorAlgebraExt/DiagonalArraysTensorAlgebraExt.jl b/ext/DiagonalArraysTensorAlgebraExt/DiagonalArraysTensorAlgebraExt.jl new file mode 100644 index 0000000..fc19745 --- /dev/null +++ b/ext/DiagonalArraysTensorAlgebraExt/DiagonalArraysTensorAlgebraExt.jl @@ -0,0 +1,36 @@ +module DiagonalArraysTensorAlgebraExt + +using DiagonalArrays: Delta +using FillArrays: Eye +using TensorAlgebra: + TensorAlgebra, + AbstractBlockPermutation, + BlockedTrivialPermutation, + BlockedTuple, + FusionStyle, + fuseaxes, + matricize + +struct DeltaFusion <: FusionStyle end +TensorAlgebra.FusionStyle(::Delta) = DeltaFusion() +function matricize_delta(a::AbstractArray, biperm::AbstractBlockPermutation{2}) + ax = fuseaxes(axes(a), biperm) + return Eye{eltype(a)}(ax) +end +function TensorAlgebra.matricize( + ::DeltaFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} +) + return matricize_delta(a, biperm) +end +function TensorAlgebra.matricize( + ::DeltaFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2} +) + return matricize_delta(a, biperm) +end + +function TensorAlgebra.unmatricize(::DeltaFusion, a::Eye, ax::BlockedTuple{2}) + length(a) == prod(length, ax) || throw(DimensionMismatch("reshape sizes don't match")) + return Delta{eltype(a)}(Tuple(ax)) +end + +end diff --git a/test/Project.toml b/test/Project.toml index 279afa2..7ca368d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] @@ -16,4 +17,5 @@ LinearAlgebra = "1" SafeTestsets = "0.1" SparseArraysBase = "0.7" Suppressor = "0.2" +TensorAlgebra = "0.3.10" Test = "1" diff --git a/test/test_basics.jl b/test/test_basics.jl index 026e408..52263f0 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -152,33 +152,33 @@ using LinearAlgebra: Diagonal end @testset "delta" begin for (a, elt′) in ( - (delta(2, 3), Float64), - (delta(Base.OneTo(2), Base.OneTo(3)), Float64), - (δ(2, 3), Float64), - (δ(Base.OneTo(2), Base.OneTo(3)), Float64), - (delta((2, 3)), Float64), - (delta(Base.OneTo.((2, 3))), Float64), - (δ((2, 3)), Float64), - (δ(Base.OneTo.((2, 3))), Float64), - (delta(Bool, 2, 3), Bool), - (delta(Bool, Base.OneTo(2), Base.OneTo(3)), Bool), - (Delta{Bool}((2, 3)), Bool), - (Delta{Bool}(Base.OneTo.((2, 3))), Bool), - (δ(Bool, 2, 3), Bool), - (δ(Bool, Base.OneTo(2), Base.OneTo(3)), Bool), - (delta(Bool, (2, 3)), Bool), - (delta(Bool, Base.OneTo.((2, 3))), Bool), - (δ(Bool, (2, 3)), Bool), - (δ(Bool, Base.OneTo.((2, 3))), Bool), + (delta(2, 2), Float64), + (delta(Base.OneTo(2), Base.OneTo(2)), Float64), + (δ(2, 2), Float64), + (δ(Base.OneTo(2), Base.OneTo(2)), Float64), + (delta((2, 2)), Float64), + (delta(Base.OneTo.((2, 2))), Float64), + (δ((2, 2)), Float64), + (δ(Base.OneTo.((2, 2))), Float64), + (delta(Bool, 2, 2), Bool), + (delta(Bool, Base.OneTo(2), Base.OneTo(2)), Bool), + (Delta{Bool}((2, 2)), Bool), + (Delta{Bool}(Base.OneTo.((2, 2))), Bool), + (δ(Bool, 2, 2), Bool), + (δ(Bool, Base.OneTo(2), Base.OneTo(2)), Bool), + (delta(Bool, (2, 2)), Bool), + (delta(Bool, Base.OneTo.((2, 2))), Bool), + (δ(Bool, (2, 2)), Bool), + (δ(Bool, Base.OneTo.((2, 2))), Bool), ) @test eltype(a) === elt′ @test diaglength(a) == 2 @test a isa DiagonalArray{elt′,2} @test a isa Delta{elt′,2} - @test size(a) == (2, 3) + @test size(a) == (2, 2) @test diaglength(a) == 2 @test storedlength(a) == 2 - @test a == DiagonalArray(ones(2), (2, 3)) + @test a == DiagonalArray(ones(2), (2, 2)) @test diagview(a) == ones(2) @test diagview(a) isa Ones{elt′} @@ -189,7 +189,7 @@ using LinearAlgebra: Diagonal # https://github.com/ITensor/DiagonalArrays.jl/issues/7 @test_broken diagview(a′) isa Fill - b = randn(elt, (3, 4)) + b = randn(elt, (2, 3)) a_dest = a * b @test a_dest ≈ Array(a) * Array(b) end diff --git a/test/test_tensoralgebraext.jl b/test/test_tensoralgebraext.jl new file mode 100644 index 0000000..c9ed8ea --- /dev/null +++ b/test/test_tensoralgebraext.jl @@ -0,0 +1,11 @@ +using DiagonalArrays: Delta +using FillArrays: Eye +using TensorAlgebra: FusionStyle, matricize, tuplemortar, unmatricize +using Test: @test, @testset + +@testset "matricize, unmatricize" begin + a = Delta{Float32}(2, 2, 2) + m = matricize(a, (1,), (2, 3)) + @test m ≡ Eye{Float32}(2, 4) + @test unmatricize(FusionStyle(a), m, tuplemortar(((axes(a, 1),), (axes(a, 2), axes(a, 3))))) ≡ a +end From 2d73ebad0fcbc7937aa10c9f65d8b2208a089bd0 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 29 Jul 2025 16:54:14 -0400 Subject: [PATCH 3/3] Format --- test/test_tensoralgebraext.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_tensoralgebraext.jl b/test/test_tensoralgebraext.jl index c9ed8ea..02c0804 100644 --- a/test/test_tensoralgebraext.jl +++ b/test/test_tensoralgebraext.jl @@ -7,5 +7,7 @@ using Test: @test, @testset a = Delta{Float32}(2, 2, 2) m = matricize(a, (1,), (2, 3)) @test m ≡ Eye{Float32}(2, 4) - @test unmatricize(FusionStyle(a), m, tuplemortar(((axes(a, 1),), (axes(a, 2), axes(a, 3))))) ≡ a + @test unmatricize( + FusionStyle(a), m, tuplemortar(((axes(a, 1),), (axes(a, 2), axes(a, 3)))) + ) ≡ a end