Skip to content

Commit baf0d2f

Browse files
committed
Implement oneelement in SparseArraysBaseExt
1 parent 5576de1 commit baf0d2f

File tree

5 files changed

+36
-37
lines changed

5 files changed

+36
-37
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1616

1717
[weakdeps]
1818
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
19+
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1920

2021
[extensions]
2122
ITensorBaseDiagonalArraysExt = "DiagonalArrays"
23+
ITensorBaseSparseArraysBaseExt = ["NamedDimsArrays", "SparseArraysBase"]
2224

2325
[compat]
2426
Accessors = "0.1.39"
@@ -28,6 +30,7 @@ FillArrays = "1.13.0"
2830
LinearAlgebra = "1.10"
2931
MapBroadcast = "0.1.5"
3032
NamedDimsArrays = "0.4"
33+
SparseArraysBase = "0.2.10"
3134
UnallocatedArrays = "0.1.1"
3235
UnspecifiedTypes = "0.1.1"
3336
VectorInterface = "0.5.0"
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module ITensorBaseSparseArraysBaseExt
2+
3+
using ITensorBase: ITensor, Index
4+
using NamedDimsArrays: dename
5+
using SparseArraysBase: SparseArraysBase, oneelement
6+
7+
function SparseArraysBase.oneelement(
8+
value, index::NTuple{N,Int}, ax::NTuple{N,Index}
9+
) where {N}
10+
return ITensor(oneelement(value, index, only.(axes.(dename.(ax)))), ax)
11+
end
12+
13+
end

src/abstractitensor.jl

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -91,24 +91,6 @@ function ITensor(parent::AbstractArray)
9191
return ITensor(parent, ())
9292
end
9393

94-
# TODO:
95-
# 1. Generalize this to arbitrary dimensions.
96-
# 2. Define a basic non-ITensor version in `SparseArraysBase.jl`,
97-
# and this constructor can wrap that one. It could construct
98-
# a `OneElement` sparse object like `FillArrays.OneElement`
99-
# (https://juliaarrays.github.io/FillArrays.jl/stable/#FillArrays.OneElement).
100-
# 3. Define `oneelement(value, index::Tuple{Vargarg{Int}}, axes::Tuple{Vararg{Index}})`,
101-
# where the pair version calls out to that one.
102-
function oneelement(elt::Type{<:Number}, iv::Pair{<:Index,<:Int})
103-
a = ITensor(first(iv))
104-
a[last(iv)] = one(elt)
105-
return a
106-
end
107-
# TODO: The non-ITensor version should default to `Float64`,
108-
# like `FillArrays.OneElement`.
109-
# TODO: Make the element type `UnspecifiedOne`.
110-
oneelement(iv::Pair{<:Index,<:Int}) = oneelement(Bool, iv)
111-
11294
using Accessors: @set
11395
setdenamed(a::ITensor, denamed) = (@set a.parent = denamed)
11496
setdenamed!(a::ITensor, denamed) = (a.parent = denamed)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
55
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
66
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
77
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
8+
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
89
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
910
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1011

test/test_basics.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ using ITensorBase:
55
gettag,
66
hastag,
77
inds,
8-
oneelement,
98
plev,
109
prime,
1110
settag,
1211
tags,
1312
unsettag
1413
using DiagonalArrays: δ, delta, diagview
1514
using NamedDimsArrays: dename, name, named
15+
using SparseArraysBase: oneelement
1616
using Test: @test, @test_broken, @testset
1717

1818
@testset "ITensorBase" begin
@@ -46,6 +46,24 @@ using Test: @test, @test_broken, @testset
4646
@test plev(i) == 0
4747
@test length(tags(i)) == 0
4848
end
49+
@testset "delta" begin
50+
i, j = Index.((2, 2))
51+
for a in (
52+
delta(i, j),
53+
delta(Bool, i, j),
54+
delta((i, j)),
55+
delta(Bool, (i, j)),
56+
δ(i, j),
57+
δ(Bool, i, j),
58+
δ((i, j)),
59+
δ(Bool, (i, j)),
60+
)
61+
@test eltype(a) === Bool
62+
# TODO: Fix this.
63+
@test_broken diagview(a)
64+
@test diagview(dename(a)) == ones(2)
65+
end
66+
end
4967
@testset "oneelement" begin
5068
i = Index(3)
5169
a = oneelement(i => 2)
@@ -67,22 +85,4 @@ using Test: @test, @test_broken, @testset
6785
@test a[2] == 1
6886
@test a[3] == 0
6987
end
70-
@testset "delta" begin
71-
i, j = Index.((2, 2))
72-
for a in (
73-
delta(i, j),
74-
delta(Bool, i, j),
75-
delta((i, j)),
76-
delta(Bool, (i, j)),
77-
δ(i, j),
78-
δ(Bool, i, j),
79-
δ((i, j)),
80-
δ(Bool, (i, j)),
81-
)
82-
@test eltype(a) === Bool
83-
# TODO: Fix this.
84-
@test_broken diagview(a)
85-
@test diagview(dename(a)) == ones(2)
86-
end
87-
end
8888
end

0 commit comments

Comments
 (0)