Skip to content

Commit 833cc47

Browse files
committed
Support for specifying symmetries of axes
1 parent 08db501 commit 833cc47

File tree

4 files changed

+150
-7
lines changed

4 files changed

+150
-7
lines changed

Project.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,24 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
99

1010
[weakdeps]
11+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
12+
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
13+
LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
1114
ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
1215
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
16+
SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e"
1317

1418
[extensions]
1519
QuantumOperatorDefinitionsITensorBaseExt = ["ITensorBase", "NamedDimsArrays"]
20+
QuantumOperatorDefinitionsSymmetrySectorsExt = ["BlockArrays", "GradedUnitRanges", "LabelledNumbers", "SymmetrySectors"]
1621

1722
[compat]
23+
BlockArrays = "1.3.0"
24+
GradedUnitRanges = "0.1.2"
1825
ITensorBase = "0.1.10"
26+
LabelledNumbers = "0.1.0"
1927
LinearAlgebra = "1.10"
2028
NamedDimsArrays = "0.4.0"
2129
Random = "1.10"
30+
SymmetrySectors = "0.1.3"
2231
julia = "1.10"

ext/QuantumOperatorDefinitionsITensorBaseExt/QuantumOperatorDefinitionsITensorBaseExt.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module QuantumOperatorDefinitionsITensorBaseExt
22

3-
using ITensorBase: ITensor, Index, dag, gettag, prime
3+
using ITensorBase: ITensorBase, ITensor, Index, dag, gettag, prime
44
using NamedDimsArrays: dename
55
using QuantumOperatorDefinitions:
66
QuantumOperatorDefinitions, OpName, SiteType, StateName, has_fermion_string
@@ -15,6 +15,10 @@ function QuantumOperatorDefinitions.SiteType(r::Index)
1515
)
1616
end
1717

18+
function ITensorBase.Index(t::SiteType; kwargs...)
19+
return Index(AbstractUnitRange(t); kwargs...)
20+
end
21+
1822
function QuantumOperatorDefinitions.has_fermion_string(n::String, r::Index)
1923
return has_fermion_string(OpName(n), SiteType(r))
2024
end
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
module QuantumOperatorDefinitionsSymmetrySectorsExt
2+
3+
using BlockArrays: blocklasts, blocklengths
4+
using GradedUnitRanges: GradedOneTo, gradedrange
5+
using LabelledNumbers: label, labelled, unlabel
6+
using QuantumOperatorDefinitions:
7+
QuantumOperatorDefinitions, @SiteType_str, @SymmetryType_str, SiteType, SymmetryType, name
8+
using SymmetrySectors: ×, SectorProduct, U1, Z
9+
10+
sortedunion(a, b) = sort(union(a, b))
11+
function QuantumOperatorDefinitions.combine_axes(a1::GradedOneTo, a2::GradedOneTo)
12+
return gradedrange(
13+
map(blocklengths(a1), blocklengths(a2)) do s1, s2
14+
l1 = unlabel(s1)
15+
l2 = unlabel(s2)
16+
@assert l1 == l2
17+
labelled(l1, label(s1) × label(s2))
18+
end,
19+
)
20+
end
21+
QuantumOperatorDefinitions.combine_axes(a::GradedOneTo, b::Base.OneTo) = a
22+
QuantumOperatorDefinitions.combine_axes(a::Base.OneTo, b::GradedOneTo) = b
23+
24+
function Base.AbstractUnitRange(::SymmetryType"N", t::SiteType)
25+
return gradedrange(map(i -> SectorProduct((; N=U1(i - 1))) => 1, 1:length(t)))
26+
end
27+
function Base.AbstractUnitRange(::SymmetryType"Sz", t::SiteType)
28+
return gradedrange(map(i -> SectorProduct((; Sz=U1(i - 1))) => 1, 1:length(t)))
29+
end
30+
function Base.AbstractUnitRange(::SymmetryType"Sz↑", t::SiteType)
31+
return AbstractUnitRange(SymmetryType"Sz"(), t)
32+
end
33+
function Base.AbstractUnitRange(::SymmetryType"Sz↓", t::SiteType)
34+
return gradedrange(map(i -> SectorProduct((; Sz=U1(-(i - 1)))) => 1, 1:length(t)))
35+
end
36+
37+
function sector(symmetrytype::SymmetryType, sec)
38+
sectorname = Symbol(get(symmetrytype, :name, name(symmetrytype)))
39+
return SectorProduct(NamedTuple{(sectorname,)}((sec,)))
40+
end
41+
42+
function Base.AbstractUnitRange(s::SymmetryType"Nf", t::SiteType"Fermion")
43+
return gradedrange([sector(s, U1(0)) => 1, sector(s, U1(1)) => 1])
44+
end
45+
# TODO: Write in terms of `SymmetryType"Nf"` definition.
46+
function Base.AbstractUnitRange(s::SymmetryType"NfParity", t::SiteType"Fermion")
47+
return gradedrange([sector(s, Z{2}(0)) => 1, sector(s, Z{2}(1)) => 1])
48+
end
49+
function Base.AbstractUnitRange(s::SymmetryType"Sz", t::SiteType"Fermion")
50+
return gradedrange([sector(s, U1(0)) => 1, sector(s, U1(1)) => 1])
51+
end
52+
function Base.AbstractUnitRange(s::SymmetryType"Sz↑", t::SiteType"Fermion")
53+
return gradedrange([sector(s, U1(0)) => 1, sector(s, U1(1)) => 1])
54+
end
55+
function Base.AbstractUnitRange(s::SymmetryType"Sz↓", t::SiteType"Fermion")
56+
return gradedrange([sector(s, U1(0)) => 1, sector(s, U1(-1)) => 1])
57+
end
58+
59+
# TODO: Write in terms of `SiteType"Fermion"` definitions.
60+
function Base.AbstractUnitRange(s::SymmetryType"Nf", t::SiteType"Electron")
61+
return gradedrange([
62+
sector(s, U1(0)) => 1,
63+
sector(s, U1(1)) => 1,
64+
sector(s, U1(1)) => 1,
65+
sector(s, U1(2)) => 1,
66+
])
67+
end
68+
# TODO: Write in terms of `SymmetryType"Nf"` definition.
69+
function Base.AbstractUnitRange(s::SymmetryType"NfParity", t::SiteType"Electron")
70+
return gradedrange([
71+
sector(s, Z{2}(0)) => 1,
72+
sector(s, Z{2}(1)) => 1,
73+
sector(s, Z{2}(1)) => 1,
74+
sector(s, Z{2}(0)) => 1,
75+
])
76+
end
77+
function Base.AbstractUnitRange(s::SymmetryType"Sz", t::SiteType"Electron")
78+
return gradedrange([
79+
sector(s, U1(0)) => 1,
80+
sector(s, U1(1)) => 1,
81+
sector(s, U1(-1)) => 1,
82+
sector(s, U1(0)) => 1,
83+
])
84+
end
85+
86+
end

src/sitetype.jl

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,77 @@ struct SiteType{T,Params}
44
return new{N,typeof(params)}(params)
55
end
66
end
7-
value(::SiteType{T}) where {T} = T
7+
name(::SiteType{T}) where {T} = T
88
params(t::SiteType) = getfield(t, :params)
99
Base.getproperty(t::SiteType, name::Symbol) = getfield(params(t), name)
1010
Base.get(t::SiteType, name::Symbol, default) = get(params(t), name, default)
11-
11+
Base.haskey(t::SiteType, name::Symbol) = haskey(params(t), name)
1212
SiteType{N}(; kwargs...) where {N} = SiteType{N}((; kwargs...))
13-
1413
SiteType(s::AbstractString; kwargs...) = SiteType{Symbol(s)}(; kwargs...)
1514
SiteType(i::Integer; kwargs...) = SiteType{Symbol(i)}(; kwargs...)
1615
macro SiteType_str(s)
17-
return SiteType{Symbol(s)}
16+
return :(SiteType{$(Expr(:quote, Symbol(s)))})
1817
end
1918

2019
alias(t::SiteType) = t
2120
alias(i::Integer) = i
2221

22+
# Like `Base.Broadcast.axistype` (https://github.com/JuliaLang/julia/blob/v1.11.3/base/broadcast.jl#L536-L538)
23+
# and `BlockArrays.combine_blockaxes` (https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.3.0/src/blockbroadcast.jl#L37-L38).
24+
combine_axes(a::T, b::T) where {T} = a
25+
combine_axes(a::Base.OneTo, b::Base.OneTo) = Base.OneTo{Int}(a)
26+
function combine_axes(a, b)
27+
return UnitRange{Int}(a)
28+
end
29+
combine_axes(a) = a
30+
combine_axes(a, b, rest...) = combine_axes(combine_axes(a, b), rest...)
31+
32+
struct SymmetryType{T,Params}
33+
params::Params
34+
function SymmetryType{N}(params::NamedTuple) where {N}
35+
return new{N,typeof(params)}(params)
36+
end
37+
end
38+
name(::SymmetryType{T}) where {T} = T
39+
params(t::SymmetryType) = getfield(t, :params)
40+
Base.getproperty(t::SymmetryType, name::Symbol) = getfield(params(t), name)
41+
Base.get(t::SymmetryType, name::Symbol, default) = get(params(t), name, default)
42+
Base.haskey(t::SymmetryType, name::Symbol) = haskey(params(t), name)
43+
SymmetryType{N}(; kwargs...) where {N} = SymmetryType{N}((; kwargs...))
44+
SymmetryType(s::AbstractString; kwargs...) = SymmetryType{Symbol(s)}(; kwargs...)
45+
function SymmetryType(s::Pair{<:AbstractString,<:AbstractString}; kwargs...)
46+
return SymmetryType(first(s); kwargs..., name=last(s))
47+
end
48+
function SymmetryType(s::Pair{<:AbstractString,<:NamedTuple}; kwargs...)
49+
return SymmetryType(first(s); kwargs..., last(s)...)
50+
end
51+
macro SymmetryType_str(s)
52+
return :(SymmetryType{$(Expr(:quote, Symbol(s)))})
53+
end
54+
55+
function Base.AbstractUnitRange(symmetry::SymmetryType, t::SiteType)
56+
return error("Not implemented.")
57+
end
58+
function Base.AbstractUnitRange(symmetry::SymmetryType"Trivial", t::SiteType)
59+
return Base.OneTo(length(t))
60+
end
61+
2362
function Base.length(t::SiteType)
2463
t′ = alias(t)
2564
if t == t′
26-
return t.length
65+
return t.dim
2766
end
2867
return length(t′)
2968
end
3069
function Base.AbstractUnitRange(t::SiteType)
3170
# This logic allows specifying a range with extra properties,
3271
# like ones with symmetry sectors.
33-
return get(t, :range, Base.OneTo(length(t)))
72+
haskey(t, :range) && return t.range
73+
if haskey(t, :symmetries)
74+
rs = map(symmetry -> AbstractUnitRange(SymmetryType(symmetry), t), t.symmetries)
75+
return combine_axes(Base.OneTo(length(t)), rs...)
76+
end
77+
return Base.OneTo(length(t))
3478
end
3579
Base.size(t::SiteType) = (length(t),)
3680
Base.size(t::SiteType, dim::Integer) = size(t)[dim]

0 commit comments

Comments
 (0)