Skip to content

Commit 113a1fe

Browse files
committed
fix tests
1 parent c373729 commit 113a1fe

File tree

9 files changed

+35
-49
lines changed

9 files changed

+35
-49
lines changed

NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/Project.toml

Lines changed: 0 additions & 2 deletions
This file was deleted.

NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/src/GradedAxesSectorsExt.jl

Lines changed: 0 additions & 9 deletions
This file was deleted.

NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/Project.toml

Lines changed: 0 additions & 3 deletions
This file was deleted.

NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/runtests.jl

Lines changed: 0 additions & 15 deletions
This file was deleted.
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
module GradedAxes
22
include("blockedunitrange.jl")
33
include("gradedunitrange.jl")
4-
include("gradedunitrangedual.jl")
54
include("dual.jl")
5+
include("gradedunitrangedual.jl")
66
include("unitrangedual.jl")
7-
include("../ext/GradedAxesSectorsExt/src/GradedAxesSectorsExt.jl")
7+
include("fusion.jl")
88
end

NDTensors/src/lib/GradedAxes/src/dual.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ using NDTensors.LabelledNumbers:
66
label_dual(x) = label_dual(LabelledStyle(x), x)
77
label_dual(::NotLabelled, x) = x
88
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))
9+
10+
flip(g::AbstractGradedUnitRange) = dual(gradedrange(label_dual.(blocklengths(g))))

NDTensors/src/lib/GradedAxes/src/fusion.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ OneToOne() = OneToOne{Bool}()
66
Base.first(a::OneToOne) = one(eltype(a))
77
Base.last(a::OneToOne) = one(eltype(a))
88

9+
gradedisequal(::AbstractUnitRange, ::OneToOne) = false
10+
gradedisequal(::OneToOne, ::AbstractUnitRange) = false
11+
gradedisequal(::OneToOne, ::OneToOne) = true
12+
913
# https://github.com/ITensor/ITensors.jl/blob/v0.3.57/NDTensors/src/lib/GradedAxes/src/tensor_product.jl
1014
# https://en.wikipedia.org/wiki/Tensor_product
1115
# https://github.com/KeitaNakamura/Tensorial.jl
@@ -18,7 +22,7 @@ function tensor_product(
1822
return foldl(tensor_product, (a1, a2, a3, a_rest...))
1923
end
2024

21-
function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange)
25+
function tensor_product(::AbstractUnitRange, ::AbstractUnitRange)
2226
return error("Not implemented yet.")
2327
end
2428

@@ -34,7 +38,7 @@ function tensor_product(a1::AbstractBlockedUnitRange, ::OneToOne)
3438
return a1
3539
end
3640

37-
function tensor_product(a1::OneToOne, a2::OneToOne)
41+
function tensor_product(::OneToOne, ::OneToOne)
3842
return OneToOne()
3943
end
4044

@@ -66,18 +70,20 @@ function fuse_blocklengths(x::LabelledInteger, y::LabelledInteger)
6670
return labelled(unlabel(x) * unlabel(y), fuse_labels(label(x), label(y)))
6771
end
6872

73+
flatten_maybe_nested(v::Vector{<:Integer}) = v
74+
flatten_maybe_nested(v::Vector{<:AbstractGradedUnitRange}) = reduce(vcat, blocklengths.(v))
75+
6976
using BlockArrays: blockedrange, blocks
7077
function tensor_product(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
71-
blocklengths = map(vec(collect(Iterators.product(blocks(a1), blocks(a2))))) do x
72-
return mapreduce(length, fuse_blocklengths, x)
73-
end
78+
maybe_nested = map(
79+
it -> mapreduce(length, fuse_blocklengths, it),
80+
Iterators.flatten((Iterators.product(blocks(a1), blocks(a2)),)),
81+
)
82+
blocklengths = flatten_maybe_nested(maybe_nested)
7483
return blockedrange(blocklengths)
7584
end
7685

7786
function blocksortperm(a::AbstractBlockedUnitRange)
78-
# TODO: Figure out how to deal with dual sectors.
79-
# TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`.
80-
## return Block.(sortperm(nondual_sectors(a); rev=isdual(a)))
8187
return Block.(sortperm(blocklabels(a)))
8288
end
8389

@@ -101,12 +107,6 @@ end
101107
# Get the permutation for sorting, then group by common elements.
102108
# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]]
103109
function blockmergesortperm(a::AbstractBlockedUnitRange)
104-
# If it is dual, reverse the sorting so the sectors
105-
# end up sorted in the same way whether or not the space
106-
# is dual.
107-
# TODO: Figure out how to deal with dual sectors.
108-
# TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`.
109-
## return Block.(groupsortperm(nondual_sectors(a); rev=isdual(a)))
110110
return Block.(groupsortperm(blocklabels(a)))
111111
end
112112

NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using BlockArrays:
1212
blockedrange,
1313
BlockIndexRange,
1414
blockfirsts,
15-
blocklasts,
15+
blockisequal,
1616
blocklength,
1717
blocklengths,
1818
findblock,
@@ -37,6 +37,11 @@ function Base.OrdinalRange{T,T}(a::GradedOneTo{<:LabelledInteger{T}}) where {T}
3737
return unlabel_blocks(a)
3838
end
3939

40+
# == is just a range comparison that ignores labels. Need dedicated function to check equality.
41+
function gradedisequal(a1::AbstractUnitRange, a2::AbstractUnitRange)
42+
return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2))
43+
end
44+
4045
# This is only needed in certain Julia versions below 1.10
4146
# (for example Julia 1.6).
4247
# TODO: Delete this once we drop Julia 1.6 support.

NDTensors/src/lib/GradedAxes/test/test_dual.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,17 @@ end
9191
ad = dual(a)
9292
@test ad isa GradedUnitRangeDual
9393
@test eltype(ad) == LabelledInteger{Int,U1}
94-
@test dual(ad) == a
95-
@test nondual(ad) == a
96-
@test nondual(a) == a
94+
95+
@test gradedisequal(dual(ad), a)
96+
@test gradedisequal(nondual(ad), a)
97+
@test gradedisequal(nondual(a), a)
98+
@test gradedisequal(ad, ad)
99+
@test !gradedisequal(a, ad)
100+
@test !gradedisequal(ad, a)
101+
102+
@test isdual(ad)
103+
@test !isdual(a)
104+
97105
@test blockfirsts(ad) == [labelled(1, U1(0)), labelled(3, U1(-1))]
98106
@test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))]
99107
@test blocklength(ad) == 2

0 commit comments

Comments
 (0)