Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiagonalArrays"
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.10"
version = "0.3.11"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand All @@ -10,10 +10,17 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"

[weakdeps]
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"

[extensions]
DiagonalArraysTensorAlgebraExt = "TensorAlgebra"

[compat]
ArrayLayouts = "1.10.4"
DerivableInterfaces = "0.5"
FillArrays = "1.13.0"
LinearAlgebra = "1.10.0"
SparseArraysBase = "0.7.2"
TensorAlgebra = "0.3.10"
julia = "1.10"
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
module DiagonalArraysTensorAlgebraExt

using DiagonalArrays: Delta
using FillArrays: Eye
using TensorAlgebra:
TensorAlgebra,
AbstractBlockPermutation,
BlockedTrivialPermutation,
BlockedTuple,
FusionStyle,
fuseaxes,
matricize

struct DeltaFusion <: FusionStyle end
TensorAlgebra.FusionStyle(::Delta) = DeltaFusion()
function matricize_delta(a::AbstractArray, biperm::AbstractBlockPermutation{2})
ax = fuseaxes(axes(a), biperm)
return Eye{eltype(a)}(ax)
end
function TensorAlgebra.matricize(
::DeltaFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
)
return matricize_delta(a, biperm)
end
function TensorAlgebra.matricize(
::DeltaFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
)
return matricize_delta(a, biperm)
end

function TensorAlgebra.unmatricize(::DeltaFusion, a::Eye, ax::BlockedTuple{2})
length(a) == prod(length, ax) || throw(DimensionMismatch("reshape sizes don't match"))
return Delta{eltype(a)}(Tuple(ax))
end

end
31 changes: 28 additions & 3 deletions src/diagonalarray/delta.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,34 @@
using FillArrays: Ones
using FillArrays: Ones, OnesVector

const Delta{T,N,V<:OnesVector{T},Axes} = DiagonalArray{T,N,V,Axes}
function Delta{T}(
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
) where {T}
uniquelens = unique(length, ax)
if !isone(length(uniquelens))
throw(ArgumentError("All axes must have the same length for Delta."))
end
return DiagonalArray{T}(Ones{T}(only(uniquelens)), ax)
end
function Delta{T}(
ax1::AbstractUnitRange{<:Integer}, ax_rest::AbstractUnitRange{<:Integer}...
) where {T}
return Delta{T}((ax1, ax_rest...))
end
function Delta{T}(sz::Tuple{Integer,Vararg{Integer}}) where {T}
return Delta{T}(map(Base.OneTo, sz))
end
function Delta{T}(sz1::Integer, sz_rest::Integer...) where {T}
return Delta{T}((sz1, sz_rest...))
end
function Delta{T}(ax::Tuple{}) where {T}
return DiagonalArray{T}(Ones{T}(0), ax)
end

function delta(
elt::Type, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
)
return DiagonalArray(Ones{elt}(minimum(length, ax)), ax)
return Delta{elt}(ax)
end
function δ(
elt::Type, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
Expand Down Expand Up @@ -35,7 +60,7 @@ function δ(ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer}
end

function delta(elt::Type, size::Tuple{Vararg{Int}})
return DiagonalArray(Ones{elt}(minimum(size)), size)
return Delta{elt}(size)
end
function δ(elt::Type, size::Tuple{Vararg{Int}})
return delta(elt, size)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Expand All @@ -16,4 +17,5 @@ LinearAlgebra = "1"
SafeTestsets = "0.1"
SparseArraysBase = "0.7"
Suppressor = "0.2"
TensorAlgebra = "0.3.10"
Test = "1"
44 changes: 24 additions & 20 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test: @test, @testset, @test_broken, @inferred
using DiagonalArrays:
DiagonalArrays,
Delta,
DiagonalArray,
DiagonalMatrix,
δ,
Expand Down Expand Up @@ -151,30 +152,33 @@ using LinearAlgebra: Diagonal
end
@testset "delta" begin
for (a, elt′) in (
(delta(2, 3), Float64),
(delta(Base.OneTo(2), Base.OneTo(3)), Float64),
(δ(2, 3), Float64),
(δ(Base.OneTo(2), Base.OneTo(3)), Float64),
(delta((2, 3)), Float64),
(delta(Base.OneTo.((2, 3))), Float64),
(δ((2, 3)), Float64),
(δ(Base.OneTo.((2, 3))), Float64),
(delta(Bool, 2, 3), Bool),
(delta(Bool, Base.OneTo(2), Base.OneTo(3)), Bool),
(δ(Bool, 2, 3), Bool),
(δ(Bool, Base.OneTo(2), Base.OneTo(3)), Bool),
(delta(Bool, (2, 3)), Bool),
(delta(Bool, Base.OneTo.((2, 3))), Bool),
(δ(Bool, (2, 3)), Bool),
(δ(Bool, Base.OneTo.((2, 3))), Bool),
(delta(2, 2), Float64),
(delta(Base.OneTo(2), Base.OneTo(2)), Float64),
(δ(2, 2), Float64),
(δ(Base.OneTo(2), Base.OneTo(2)), Float64),
(delta((2, 2)), Float64),
(delta(Base.OneTo.((2, 2))), Float64),
(δ((2, 2)), Float64),
(δ(Base.OneTo.((2, 2))), Float64),
(delta(Bool, 2, 2), Bool),
(delta(Bool, Base.OneTo(2), Base.OneTo(2)), Bool),
(Delta{Bool}((2, 2)), Bool),
(Delta{Bool}(Base.OneTo.((2, 2))), Bool),
(δ(Bool, 2, 2), Bool),
(δ(Bool, Base.OneTo(2), Base.OneTo(2)), Bool),
(delta(Bool, (2, 2)), Bool),
(delta(Bool, Base.OneTo.((2, 2))), Bool),
(δ(Bool, (2, 2)), Bool),
(δ(Bool, Base.OneTo.((2, 2))), Bool),
)
@test eltype(a) === elt′
@test diaglength(a) == 2
@test a isa DiagonalArray{elt′}
@test size(a) == (2, 3)
@test a isa DiagonalArray{elt′,2}
@test a isa Delta{elt′,2}
@test size(a) == (2, 2)
@test diaglength(a) == 2
@test storedlength(a) == 2
@test a == DiagonalArray(ones(2), (2, 3))
@test a == DiagonalArray(ones(2), (2, 2))
@test diagview(a) == ones(2)
@test diagview(a) isa Ones{elt′}

Expand All @@ -185,7 +189,7 @@ using LinearAlgebra: Diagonal
# https://github.com/ITensor/DiagonalArrays.jl/issues/7
@test_broken diagview(a′) isa Fill

b = randn(elt, (3, 4))
b = randn(elt, (2, 3))
a_dest = a * b
@test a_dest ≈ Array(a) * Array(b)
end
Expand Down
13 changes: 13 additions & 0 deletions test/test_tensoralgebraext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using DiagonalArrays: Delta
using FillArrays: Eye
using TensorAlgebra: FusionStyle, matricize, tuplemortar, unmatricize
using Test: @test, @testset

@testset "matricize, unmatricize" begin
a = Delta{Float32}(2, 2, 2)
m = matricize(a, (1,), (2, 3))
@test m ≡ Eye{Float32}(2, 4)
@test unmatricize(
FusionStyle(a), m, tuplemortar(((axes(a, 1),), (axes(a, 2), axes(a, 3))))
) ≡ a
end
Loading