Skip to content

Commit a0aefe6

Browse files
committed
Allow custom axes when constructing operators, ITensor extension
1 parent 3df5640 commit a0aefe6

File tree

5 files changed

+55
-19
lines changed

5 files changed

+55
-19
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
99

1010
[weakdeps]
1111
ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
12+
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
1213

1314
[extensions]
14-
QuantumOperatorDefinitionsITensorBaseExt = "ITensorBase"
15+
QuantumOperatorDefinitionsITensorBaseExt = ["ITensorBase", "NamedDimsArrays"]
1516

1617
[compat]
1718
ITensorBase = "0.1.10"
1819
LinearAlgebra = "1.10"
20+
NamedDimsArrays = "0.4.0"
1921
Random = "1.10"
2022
julia = "1.10"

ext/QuantumOperatorDefinitionsITensorBaseExt/QuantumOperatorDefinitionsITensorBaseExt.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
module QuantumOperatorDefinitionsITensorBaseExt
22

33
using ITensorBase: ITensor, Index, dag, gettag, prime
4+
using NamedDimsArrays: dename
45
using QuantumOperatorDefinitions:
56
QuantumOperatorDefinitions, OpName, SiteType, StateName, has_fermion_string
67

78
function QuantumOperatorDefinitions.SiteType(r::Index)
8-
return SiteType(gettag(r, "sitetype", "Qudit"); dim=Int(length(r)))
9+
# We pass the axis of the (unnamed) Index because
10+
# the Index may have originated from a slice, in which
11+
# case the start may not be 1 (and it may not even
12+
# be a unit range).
13+
return SiteType(
14+
gettag(r, "sitetype", "Qudit"); dim=Int.(length(r)), range=only(axes(dename(r)))
15+
)
916
end
1017

1118
function QuantumOperatorDefinitions.has_fermion_string(n::String, r::Index)
@@ -14,6 +21,7 @@ end
1421

1522
function Base.AbstractArray(n::OpName, r::Index)
1623
# TODO: Define this with mapped dimnames.
24+
# Generalize beyond prime levels with codomain and domain indices.
1725
return ITensor(AbstractArray(n, SiteType(r)), (prime(r), dag(r)))
1826
end
1927

src/op.jl

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ struct OpName{Name,Params}
88
end
99
name(::OpName{Name}) where {Name} = Name
1010
params(n::OpName) = getfield(n, :params)
11-
1211
Base.getproperty(n::OpName, name::Symbol) = getfield(params(n), name)
12+
Base.get(t::OpName, name::Symbol, default) = get(params(t), name, default)
1313

1414
OpName{N}(; kwargs...) where {N} = OpName{N}((; kwargs...))
1515

@@ -54,7 +54,9 @@ end
5454
# Generic to `StateName` or `OpName`.
5555
const StateOrOpName = Union{StateName,OpName}
5656
alias(n::StateOrOpName) = n
57-
function (arrtype::Type{<:AbstractArray})(n::StateOrOpName, domain::Integer...)
57+
function (arrtype::Type{<:AbstractArray})(
58+
n::StateOrOpName, domain::Union{Integer,AbstractUnitRange}...
59+
)
5860
return arrtype(n, domain)
5961
end
6062
(arrtype::Type{<:AbstractArray})(n::StateOrOpName, ts::SiteType...) = arrtype(n, ts)
@@ -87,32 +89,50 @@ function nsites(n::StateOrOpName)
8789
return nsites(n′)
8890
end
8991

92+
function array(a::AbstractArray, ax::Tuple{Vararg{AbstractUnitRange}})
93+
return a[ax...]
94+
end
95+
9096
function op_convert(
9197
arrtype::Type{<:AbstractArray{<:Any,N}},
92-
domain::Tuple{Vararg{Integer}},
98+
domain::Tuple{Vararg{AbstractUnitRange}},
9399
a::AbstractArray{<:Any,N},
94100
) where {N}
95-
# TODO: Check the dimensions.
96-
return convert(arrtype, a)
101+
ax = (domain..., domain...)
102+
a′ = array(a, ax)
103+
return convert(arrtype, a′)
97104
end
98105
function op_convert(
99-
arrtype::Type{<:AbstractArray}, domain::Tuple{Vararg{Integer}}, a::AbstractArray
106+
arrtype::Type{<:AbstractArray}, domain::Tuple{Vararg{AbstractUnitRange}}, a::AbstractArray
100107
)
101-
# TODO: Check the dimensions.
102-
return convert(arrtype, a)
108+
ax = (domain..., domain...)
109+
a′ = array(a, ax)
110+
return convert(arrtype, a′)
103111
end
104112
function op_convert(
105-
arrtype::Type{<:AbstractArray{<:Any,N}}, domain::Tuple{Vararg{Integer}}, a::AbstractArray
113+
arrtype::Type{<:AbstractArray{<:Any,N}},
114+
domain::Tuple{Vararg{AbstractUnitRange}},
115+
a::AbstractArray,
106116
) where {N}
107-
size = (domain..., domain...)
108-
@assert length(size) == N
109-
return convert(arrtype, reshape(a, size))
117+
ax = (domain..., domain...)
118+
@assert length(ax) == N
119+
a′ = reshape(a, length.(ax))
120+
a′′ = array(a′, ax)
121+
return convert(arrtype, a′′)
110122
end
111123
function (arrtype::Type{<:AbstractArray})(n::OpName, domain::Tuple{Vararg{SiteType}})
112-
return op_convert(arrtype, length.(domain), n(domain...))
124+
domain′ = AbstractUnitRange.(domain)
125+
return op_convert(arrtype, domain′, n(domain...))
126+
end
127+
128+
function (arrtype::Type{<:AbstractArray})(
129+
n::OpName, domain::Tuple{Vararg{AbstractUnitRange}}
130+
)
131+
# TODO: Make `(::OpName)(domain...)` constructor process more general inputs.
132+
return op_convert(arrtype, domain, n(Int.(length.(domain))...))
113133
end
114134
function (arrtype::Type{<:AbstractArray})(n::OpName, domain::Tuple{Vararg{Integer}})
115-
return op_convert(arrtype, domain, n(Int.(domain)...))
135+
return arrtype(n, Base.oneto.(domain))
116136
end
117137

118138
function op(arrtype::Type{<:AbstractArray}, n::String, domain...; kwargs...)

src/sitetype.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@ struct SiteType{T,Params}
44
return new{N,typeof(params)}(params)
55
end
66
end
7+
value(::SiteType{T}) where {T} = T
78
params(t::SiteType) = getfield(t, :params)
89
Base.getproperty(t::SiteType, name::Symbol) = getfield(params(t), name)
10+
Base.get(t::SiteType, name::Symbol, default) = get(params(t), name, default)
911

1012
SiteType{N}(; kwargs...) where {N} = SiteType{N}((; kwargs...))
1113

1214
SiteType(s::AbstractString; kwargs...) = SiteType{Symbol(s)}(; kwargs...)
1315
SiteType(i::Integer; kwargs...) = SiteType{Symbol(i)}(; kwargs...)
14-
value(::SiteType{T}) where {T} = T
1516
macro SiteType_str(s)
1617
return SiteType{Symbol(s)}
1718
end
@@ -26,7 +27,11 @@ function Base.length(t::SiteType)
2627
end
2728
return length(t′)
2829
end
29-
Base.AbstractUnitRange(t::SiteType) = Base.OneTo(length(t))
30+
function Base.AbstractUnitRange(t::SiteType)
31+
# This logic allows specifying a range with extra properties,
32+
# like ones with symmetry sectors.
33+
return get(t, :range, Base.OneTo(length(t)))
34+
end
3035
Base.size(t::SiteType) = (length(t),)
3136
Base.size(t::SiteType, dim::Integer) = size(t)[dim]
3237
Base.axes(t::SiteType) = (AbstractUnitRange(t),)

src/state.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@ struct StateName{Name,Params}
66
return new{N,typeof(params)}(params)
77
end
88
end
9+
name(::StateName{N}) where {N} = N
910
params(n::StateName) = getfield(n, :params)
1011
Base.getproperty(n::StateName, name::Symbol) = getfield(params(n), name)
12+
Base.get(t::StateName, name::Symbol, default) = get(params(t), name, default)
1113

1214
StateName{N}(; kwargs...) where {N} = StateName{N}((; kwargs...))
1315

1416
StateName(s::AbstractString; kwargs...) = StateName{Symbol(s)}(; kwargs...)
1517
StateName(s::Symbol; kwargs...) = StateName{s}(; kwargs...)
16-
name(::StateName{N}) where {N} = N
1718
macro StateName_str(s)
1819
return StateName{Symbol(s)}
1920
end

0 commit comments

Comments
 (0)