Skip to content

Commit e6df7a7

Browse files
committed
More missing functionality
1 parent 8500390 commit e6df7a7

File tree

3 files changed

+59
-5
lines changed

3 files changed

+59
-5
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.1.2"
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
88
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
99
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1112
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
1213
UnallocatedArrays = "43c9e47c-e622-40fb-bf18-a09fc8c466b6"
@@ -16,6 +17,7 @@ UnspecifiedTypes = "42b3faec-625b-4613-8ddc-352bf9672b8d"
1617
Accessors = "0.1.39"
1718
DerivableInterfaces = "0.3.7"
1819
FillArrays = "1.13.0"
20+
LinearAlgebra = "1.11.0"
1921
MapBroadcast = "0.1.5"
2022
NamedDimsArrays = "0.3.0"
2123
UnallocatedArrays = "0.1.1"

src/ITensorBase.jl

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ using NamedDimsArrays:
1818
named,
1919
nameddimsindices,
2020
setname,
21+
setnameddimsindices,
2122
unname
2223

2324
const Tag = String
@@ -79,7 +80,9 @@ struct Index{T,Value<:AbstractUnitRange{T}} <: AbstractNamedUnitRange{T,Value,In
7980
name::IndexName
8081
end
8182

82-
Index(length::Int; kwargs...) = Index(Base.OneTo(length), IndexName(; kwargs...))
83+
function Index(length::Int; tags, kwargs...)
84+
return Index(Base.OneTo(length), IndexName(; tags=tagset(tags), kwargs...))
85+
end
8386
function Index(length::Int, tags::String; kwargs...)
8487
return Index(Base.OneTo(length), IndexName(; kwargs..., tags=tagset(tags)))
8588
end
@@ -177,17 +180,30 @@ struct AllocatableArrayInterface <: AbstractAllocatableArrayInterface end
177180

178181
unallocatable(a::AbstractITensor) = NamedDimsArray(a)
179182

180-
@interface ::AbstractAllocatableArrayInterface function Base.setindex!(
181-
a::AbstractArray, value, I::Int...
182-
)
183+
function setindex_allocatable!(a::AbstractArray, value, I...)
183184
allocate!(specify_eltype!(a, typeof(value)))
184185
# TODO: Maybe use `@interface interface(a) a[I...] = value`?
185186
unallocatable(a)[I...] = value
186187
return a
187188
end
188189

190+
# TODO: Combine these by using `Base.to_indices`.
191+
@interface ::AbstractAllocatableArrayInterface function Base.setindex!(
192+
a::AbstractArray, value, I::Int...
193+
)
194+
setindex_allocatable!(a, value, I...)
195+
return a
196+
end
197+
@interface ::AbstractAllocatableArrayInterface function Base.setindex!(
198+
a::AbstractArray, value, I::AbstractNamedInteger...
199+
)
200+
setindex_allocatable!(a, value, I...)
201+
return a
202+
end
203+
189204
@derive AllocatableArrayInterface() (T=AbstractITensor,) begin
190205
Base.setindex!(::T, ::Any, ::Int...)
206+
Base.setindex!(::T, ::Any, ::AbstractNamedInteger...)
191207
end
192208

193209
mutable struct ITensor <: AbstractITensor
@@ -216,6 +232,25 @@ function ITensor()
216232
return ITensor(Zeros{UnspecifiedZero}(), ())
217233
end
218234

235+
inds(a::AbstractITensor) = nameddimsindices(a)
236+
setinds(a::AbstractITensor, inds) = setnameddimsindices(a, inds)
237+
238+
function uniqueinds(a1::AbstractITensor, a_rest::AbstractITensor...)
239+
return setdiff(inds(a1), inds.(a_rest)...)
240+
end
241+
function uniqueind(a1::AbstractITensor, a_rest::AbstractITensor...)
242+
return only(uniqueinds(a1, a_rest...))
243+
end
244+
245+
function commoninds(a1::AbstractITensor, a_rest::AbstractITensor...)
246+
return intersect(inds(a1), inds.(a_rest)...)
247+
end
248+
function commonind(a1::AbstractITensor, a_rest::AbstractITensor...)
249+
return only(commoninds(a1, a_rest...))
250+
end
251+
252+
prime(a::AbstractITensor) = setinds(a, prime.(inds(a)))
253+
219254
include("quirks.jl")
220255

221256
end

src/quirks.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,28 @@ dag(i::Index) = i
44
dim(i::Index) = dename(length(i))
55
# TODO: Define this properly.
66
hasqns(i::Index) = false
7-
inds(a::ITensor) = nameddimsindices(a)
87
# TODO: Deprecate.
98
itensor(parent::AbstractArray, nameddimsindices) = ITensor(parent, nameddimsindices)
109
function itensor(parent::AbstractArray, i1::Index, i_rest::Index...)
1110
return ITensor(parent, (i1, i_rest...))
1211
end
1312

13+
# This seems to be needed to get broadcasting working.
14+
# TODO: Investigate this and see if we can get rid of it.
1415
Base.Broadcast.extrude(a::AbstractITensor) = a
16+
17+
# TODO: Generalize this.
18+
# Maybe define it as `oneelement`, and base it on
19+
# `FillArrays.OneElement` (https://juliaarrays.github.io/FillArrays.jl/stable/#FillArrays.OneElement).
20+
function onehot(iv::Pair{<:Index,<:Int})
21+
a = ITensor(first(iv))
22+
a[last(iv)] = one(Bool)
23+
return a
24+
end
25+
26+
using LinearAlgebra: svd
27+
# TODO: Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
28+
function factorize(a::AbstractITensor, args...; kwargs...)
29+
U, S, V = svd(a, args...; kwargs...)
30+
return U, S * V
31+
end

0 commit comments

Comments
 (0)