Skip to content

Commit ea8a9a6

Browse files
authored
Define Delta type alias, TensorAlgebra.matricize (#33)
1 parent cb23366 commit ea8a9a6

File tree

6 files changed

+111
-24
lines changed

6 files changed

+111
-24
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.10"
4+
version = "0.3.11"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -10,10 +10,17 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1212

13+
[weakdeps]
14+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
15+
16+
[extensions]
17+
DiagonalArraysTensorAlgebraExt = "TensorAlgebra"
18+
1319
[compat]
1420
ArrayLayouts = "1.10.4"
1521
DerivableInterfaces = "0.5"
1622
FillArrays = "1.13.0"
1723
LinearAlgebra = "1.10.0"
1824
SparseArraysBase = "0.7.2"
25+
TensorAlgebra = "0.3.10"
1926
julia = "1.10"
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
module DiagonalArraysTensorAlgebraExt
2+
3+
using DiagonalArrays: Delta
4+
using FillArrays: Eye
5+
using TensorAlgebra:
6+
TensorAlgebra,
7+
AbstractBlockPermutation,
8+
BlockedTrivialPermutation,
9+
BlockedTuple,
10+
FusionStyle,
11+
fuseaxes,
12+
matricize
13+
14+
struct DeltaFusion <: FusionStyle end
15+
TensorAlgebra.FusionStyle(::Delta) = DeltaFusion()
16+
function matricize_delta(a::AbstractArray, biperm::AbstractBlockPermutation{2})
17+
ax = fuseaxes(axes(a), biperm)
18+
return Eye{eltype(a)}(ax)
19+
end
20+
function TensorAlgebra.matricize(
21+
::DeltaFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
22+
)
23+
return matricize_delta(a, biperm)
24+
end
25+
function TensorAlgebra.matricize(
26+
::DeltaFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
27+
)
28+
return matricize_delta(a, biperm)
29+
end
30+
31+
function TensorAlgebra.unmatricize(::DeltaFusion, a::Eye, ax::BlockedTuple{2})
32+
length(a) == prod(length, ax) || throw(DimensionMismatch("reshape sizes don't match"))
33+
return Delta{eltype(a)}(Tuple(ax))
34+
end
35+
36+
end

src/diagonalarray/delta.jl

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,34 @@
1-
using FillArrays: Ones
1+
using FillArrays: Ones, OnesVector
2+
3+
const Delta{T,N,V<:OnesVector{T},Axes} = DiagonalArray{T,N,V,Axes}
4+
function Delta{T}(
5+
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
6+
) where {T}
7+
uniquelens = unique(length, ax)
8+
if !isone(length(uniquelens))
9+
throw(ArgumentError("All axes must have the same length for Delta."))
10+
end
11+
return DiagonalArray{T}(Ones{T}(only(uniquelens)), ax)
12+
end
13+
function Delta{T}(
14+
ax1::AbstractUnitRange{<:Integer}, ax_rest::AbstractUnitRange{<:Integer}...
15+
) where {T}
16+
return Delta{T}((ax1, ax_rest...))
17+
end
18+
function Delta{T}(sz::Tuple{Integer,Vararg{Integer}}) where {T}
19+
return Delta{T}(map(Base.OneTo, sz))
20+
end
21+
function Delta{T}(sz1::Integer, sz_rest::Integer...) where {T}
22+
return Delta{T}((sz1, sz_rest...))
23+
end
24+
function Delta{T}(ax::Tuple{}) where {T}
25+
return DiagonalArray{T}(Ones{T}(0), ax)
26+
end
227

328
function delta(
429
elt::Type, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
530
)
6-
return DiagonalArray(Ones{elt}(minimum(length, ax)), ax)
31+
return Delta{elt}(ax)
732
end
833
function δ(
934
elt::Type, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
@@ -35,7 +60,7 @@ function δ(ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer}
3560
end
3661

3762
function delta(elt::Type, size::Tuple{Vararg{Int}})
38-
return DiagonalArray(Ones{elt}(minimum(size)), size)
63+
return Delta{elt}(size)
3964
end
4065
function δ(elt::Type, size::Tuple{Vararg{Int}})
4166
return delta(elt, size)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
66
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
77
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
88
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
9+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
910
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1011

1112
[compat]
@@ -16,4 +17,5 @@ LinearAlgebra = "1"
1617
SafeTestsets = "0.1"
1718
SparseArraysBase = "0.7"
1819
Suppressor = "0.2"
20+
TensorAlgebra = "0.3.10"
1921
Test = "1"

test/test_basics.jl

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Test: @test, @testset, @test_broken, @inferred
22
using DiagonalArrays:
33
DiagonalArrays,
4+
Delta,
45
DiagonalArray,
56
DiagonalMatrix,
67
δ,
@@ -151,30 +152,33 @@ using LinearAlgebra: Diagonal
151152
end
152153
@testset "delta" begin
153154
for (a, elt′) in (
154-
(delta(2, 3), Float64),
155-
(delta(Base.OneTo(2), Base.OneTo(3)), Float64),
156-
(δ(2, 3), Float64),
157-
(δ(Base.OneTo(2), Base.OneTo(3)), Float64),
158-
(delta((2, 3)), Float64),
159-
(delta(Base.OneTo.((2, 3))), Float64),
160-
(δ((2, 3)), Float64),
161-
(δ(Base.OneTo.((2, 3))), Float64),
162-
(delta(Bool, 2, 3), Bool),
163-
(delta(Bool, Base.OneTo(2), Base.OneTo(3)), Bool),
164-
(δ(Bool, 2, 3), Bool),
165-
(δ(Bool, Base.OneTo(2), Base.OneTo(3)), Bool),
166-
(delta(Bool, (2, 3)), Bool),
167-
(delta(Bool, Base.OneTo.((2, 3))), Bool),
168-
(δ(Bool, (2, 3)), Bool),
169-
(δ(Bool, Base.OneTo.((2, 3))), Bool),
155+
(delta(2, 2), Float64),
156+
(delta(Base.OneTo(2), Base.OneTo(2)), Float64),
157+
(δ(2, 2), Float64),
158+
(δ(Base.OneTo(2), Base.OneTo(2)), Float64),
159+
(delta((2, 2)), Float64),
160+
(delta(Base.OneTo.((2, 2))), Float64),
161+
(δ((2, 2)), Float64),
162+
(δ(Base.OneTo.((2, 2))), Float64),
163+
(delta(Bool, 2, 2), Bool),
164+
(delta(Bool, Base.OneTo(2), Base.OneTo(2)), Bool),
165+
(Delta{Bool}((2, 2)), Bool),
166+
(Delta{Bool}(Base.OneTo.((2, 2))), Bool),
167+
(δ(Bool, 2, 2), Bool),
168+
(δ(Bool, Base.OneTo(2), Base.OneTo(2)), Bool),
169+
(delta(Bool, (2, 2)), Bool),
170+
(delta(Bool, Base.OneTo.((2, 2))), Bool),
171+
(δ(Bool, (2, 2)), Bool),
172+
(δ(Bool, Base.OneTo.((2, 2))), Bool),
170173
)
171174
@test eltype(a) === elt′
172175
@test diaglength(a) == 2
173-
@test a isa DiagonalArray{elt′}
174-
@test size(a) == (2, 3)
176+
@test a isa DiagonalArray{elt′,2}
177+
@test a isa Delta{elt′,2}
178+
@test size(a) == (2, 2)
175179
@test diaglength(a) == 2
176180
@test storedlength(a) == 2
177-
@test a == DiagonalArray(ones(2), (2, 3))
181+
@test a == DiagonalArray(ones(2), (2, 2))
178182
@test diagview(a) == ones(2)
179183
@test diagview(a) isa Ones{elt′}
180184

@@ -185,7 +189,7 @@ using LinearAlgebra: Diagonal
185189
# https://github.com/ITensor/DiagonalArrays.jl/issues/7
186190
@test_broken diagview(a′) isa Fill
187191

188-
b = randn(elt, (3, 4))
192+
b = randn(elt, (2, 3))
189193
a_dest = a * b
190194
@test a_dest Array(a) * Array(b)
191195
end

test/test_tensoralgebraext.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using DiagonalArrays: Delta
2+
using FillArrays: Eye
3+
using TensorAlgebra: FusionStyle, matricize, tuplemortar, unmatricize
4+
using Test: @test, @testset
5+
6+
@testset "matricize, unmatricize" begin
7+
a = Delta{Float32}(2, 2, 2)
8+
m = matricize(a, (1,), (2, 3))
9+
@test m Eye{Float32}(2, 4)
10+
@test unmatricize(
11+
FusionStyle(a), m, tuplemortar(((axes(a, 1),), (axes(a, 2), axes(a, 3))))
12+
) a
13+
end

0 commit comments

Comments
 (0)