Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@eval module $(gensym())
using Compat: Returns
using Test: @test, @testset, @test_broken
using Test: @test, @testset
using BlockArrays:
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
Expand Down Expand Up @@ -217,10 +217,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test size(a[I, I]) == (1, 1)
@test isdual(axes(a[I, :], 2))
@test isdual(axes(a[:, I], 1))
@test_broken isdual(axes(a[I, :], 1))
@test_broken isdual(axes(a[:, I], 2))
@test_broken isdual(axes(a[I, I], 1))
@test_broken isdual(axes(a[I, I], 2))
@test isdual(axes(a[I, :], 1))
@test isdual(axes(a[:, I], 2))
@test isdual(axes(a[I, I], 1))
@test isdual(axes(a[I, I], 2))
end

@testset "dual GradedUnitRange" begin
Expand All @@ -243,10 +243,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test size(a[I, I]) == (1, 1)
@test isdual(axes(a[I, :], 2))
@test isdual(axes(a[:, I], 1))
@test_broken isdual(axes(a[I, :], 1))
@test_broken isdual(axes(a[:, I], 2))
@test_broken isdual(axes(a[I, I], 1))
@test_broken isdual(axes(a[I, I], 2))
@test isdual(axes(a[I, :], 1))
@test isdual(axes(a[:, I], 2))
@test isdual(axes(a[I, I], 1))
@test isdual(axes(a[I, I], 2))
end

@testset "dual BlockedUnitRange" begin # self dual
Expand Down
1 change: 1 addition & 0 deletions NDTensors/src/lib/GradedAxes/src/GradedAxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module GradedAxes
include("blockedunitrange.jl")
include("gradedunitrange.jl")
include("dual.jl")
include("labelledunitrangedual.jl")
include("gradedunitrangedual.jl")
include("onetoone.jl")
include("fusion.jl")
Expand Down
5 changes: 3 additions & 2 deletions NDTensors/src/lib/GradedAxes/src/dual.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# default behavior: self-dual
dual(r::AbstractUnitRange) = r
# default behavior: any object is self-dual
dual(x) = x
nondual(r::AbstractUnitRange) = r
isdual(::AbstractUnitRange) = false

Expand All @@ -11,4 +11,5 @@ label_dual(x) = label_dual(LabelledStyle(x), x)
label_dual(::NotLabelled, x) = x
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))

flip(a::AbstractUnitRange) = dual(label_dual(a))
flip(g::AbstractGradedUnitRange) = dual(gradedrange(label_dual.(blocklengths(g))))
1 change: 1 addition & 0 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ end
# == is just a range comparison that ignores labels. Need dedicated function to check equality.
struct NoLabel end
blocklabels(r::AbstractUnitRange) = Fill(NoLabel(), blocklength(r))
blocklabels(la::LabelledUnitRange) = [label(la)]

function LabelledNumbers.labelled_isequal(a1::AbstractUnitRange, a2::AbstractUnitRange)
return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2))
Expand Down
53 changes: 44 additions & 9 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer)
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1})
return label_dual(getindex(nondual(a), indices))
return dual(getindex(nondual(a), indices))
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockRange)
return label_dual(getindex(nondual(a), indices))
return dual(getindex(nondual(a), indices))
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockIndexRange)
return dual(nondual(a)[indices])
end

# fix ambiguity
Expand All @@ -49,20 +53,51 @@ function BlockArrays.blocklengths(a::GradedUnitRangeDual)
return dual.(blocklengths(nondual(a)))
end

function gradedunitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices)
# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
)
a_indices = getindex(nondual(a), indices)
return mortar([label_dual(b) for b in blocks(a_indices)])
v = mortar(dual.(blocks(a_indices)))
# flip v to stay consistent with other cases where axes(v) are used
return flip_blockvector(v)
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Vector{<:Block{1}})
return gradedunitrangedual_getindices_blocks(a, indices)
function blockedunitrange_getindices(
a::GradedUnitRangeDual,
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},
)
v = mortar(map(b -> a[b], blocks(indices)))
# GradedOneTo appears in mortar
# flip v axis to preserve dual information
# axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)[1:1]]))
return flip_blockvector(v)
end

function blockedunitrange_getindices(
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
a::GradedUnitRangeDual, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
)
return gradedunitrangedual_getindices_blocks(a, indices)
# Without converting `indices` to `Vector`,
# mapping `indices` outputs a `BlockVector`
# which is harder to reason about.
vblocks = map(index -> a[index], Vector(indices))
# We pass `length.(blocks)` to `mortar` in order
# to pass block labels to the axes of the output,
# if they exist. This makes it so that
# `only(axes(a[indices])) isa `GradedUnitRange`
# if `a isa `GradedUnitRange`, for example.

v = mortar(vblocks, length.(vblocks))
# GradedOneTo appears in mortar
# flip v axis to preserve dual information
# axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)]))
return flip_blockvector(v)
end

function flip_blockvector(v::BlockVector)
block_axes = flip.(axes(v))
flipped = mortar(vec.(blocks(v)), block_axes)
return flipped
end

Base.axes(a::GradedUnitRangeDual) = axes(nondual(a))
Expand Down
49 changes: 49 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# LabelledUnitRangeDual is obtained by slicing a GradedUnitRangeDual with a block

using ..LabelledNumbers: LabelledNumbers, label, labelled, unlabel

struct LabelledUnitRangeDual{T,NondualUnitRange<:AbstractUnitRange{T}} <:
AbstractUnitRange{T}
nondual_unitrange::NondualUnitRange
end

dual(a::LabelledUnitRange) = LabelledUnitRangeDual(a)
nondual(a::LabelledUnitRangeDual) = a.nondual_unitrange
dual(a::LabelledUnitRangeDual) = nondual(a)
label_dual(::IsLabelled, a::LabelledUnitRangeDual) = dual(label_dual(nondual(a)))
isdual(::LabelledUnitRangeDual) = true
blocklabels(la::LabelledUnitRangeDual) = [label(la)]

LabelledNumbers.label(a::LabelledUnitRangeDual) = dual(label(nondual(a)))
LabelledNumbers.unlabel(a::LabelledUnitRangeDual) = unlabel(nondual(a))
LabelledNumbers.LabelledStyle(::LabelledUnitRangeDual) = IsLabelled()

for f in [:first, :getindex, :last, :length, :step]
@eval Base.$f(a::LabelledUnitRangeDual, args...) =
labelled($f(unlabel(a), args...), label(a))
end

# fix ambiguities
Base.getindex(a::LabelledUnitRangeDual, i::Integer) = dual(nondual(a)[i])
function Base.getindex(a::LabelledUnitRangeDual, indices::AbstractUnitRange{<:Integer})
return dual(nondual(a)[indices])
end

function Base.iterate(a::LabelledUnitRangeDual, i)
i == last(a) && return nothing
next = convert(eltype(a), labelled(i + step(a), label(a)))
return (next, next)
end

function Base.show(io::IO, ::MIME"text/plain", a::LabelledUnitRangeDual)
println(io, typeof(a))
return print(io, label(a), " => ", unlabel(a))
end

function Base.show(io::IO, a::LabelledUnitRangeDual)
return print(io, nameof(typeof(a)), " ", label(a), " => ", unlabel(a))
end

function Base.AbstractUnitRange{T}(a::LabelledUnitRangeDual) where {T}
return AbstractUnitRange{T}(nondual(a))
end
104 changes: 100 additions & 4 deletions NDTensors/src/lib/GradedAxes/test/test_dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using NDTensors.GradedAxes:
AbstractGradedUnitRange,
GradedAxes,
GradedUnitRangeDual,
LabelledUnitRangeDual,
OneToOne,
blocklabels,
blockmergesortperm,
Expand All @@ -27,7 +28,8 @@ using NDTensors.GradedAxes:
gradedrange,
isdual,
nondual
using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, labelled_isequal
using NDTensors.LabelledNumbers:
LabelledInteger, LabelledUnitRange, label, label_type, labelled, labelled_isequal, unlabel
using Test: @test, @test_broken, @testset
struct U1
n::Int
Expand Down Expand Up @@ -58,6 +60,92 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n
@test blockisequal(ad, a)
end

@testset "LabelledUnitRangeDual" begin
la = labelled(1:2, U1(1))
@test la isa LabelledUnitRange
@test label(la) == U1(1)
@test blocklabels(la) == [U1(1)]
@test unlabel(la) == 1:2
@test la == 1:2
@test !isdual(la)
@test labelled_isequal(la, la)
@test space_isequal(la, la)
@test label_type(la) == U1

@test iterate(la) == (1, 1)
@test iterate(la) == (1, 1)
@test iterate(la, 1) == (2, 2)
@test isnothing(iterate(la, 2))

lad = dual(la)
@test lad isa LabelledUnitRangeDual
@test label(lad) == U1(-1)
@test blocklabels(lad) == [U1(-1)]
@test unlabel(lad) == 1:2
@test lad == 1:2
@test labelled_isequal(lad, lad)
@test space_isequal(lad, lad)
@test !labelled_isequal(la, lad)
@test !space_isequal(la, lad)
@test isdual(lad)
@test nondual(lad) === la
@test dual(lad) === la
@test label_type(lad) == U1

@test iterate(lad) == (1, 1)
@test iterate(lad) == (1, 1)
@test iterate(lad, 1) == (2, 2)
@test isnothing(iterate(lad, 2))

lad2 = lad[1:1]
@test lad2 isa LabelledUnitRangeDual
@test label(lad2) == U1(-1)
@test unlabel(lad2) == 1:1

laf = flip(la)
@test laf isa LabelledUnitRangeDual
@test label(laf) == U1(1)
@test unlabel(laf) == 1:2
@test labelled_isequal(la, laf)
@test !space_isequal(la, laf)

ladf = flip(dual(la))
@test ladf isa LabelledUnitRange
@test label(ladf) == U1(-1)
@test unlabel(ladf) == 1:2

lafd = dual(flip(la))
@test lafd isa LabelledUnitRange
@test label(lafd) == U1(-1)
@test unlabel(lafd) == 1:2

# check default behavior for objects without dual
la = labelled(1:2, 'x')
lad = dual(la)
@test lad isa LabelledUnitRangeDual
@test label(lad) == 'x'
@test blocklabels(lad) == ['x']
@test unlabel(lad) == 1:2
@test lad == 1:2
@test labelled_isequal(lad, lad)
@test space_isequal(lad, lad)
@test labelled_isequal(la, lad)
@test !space_isequal(la, lad)
@test isdual(lad)
@test nondual(lad) === la
@test dual(lad) === la

laf = flip(la)
@test laf isa LabelledUnitRangeDual
@test label(laf) == 'x'
@test unlabel(laf) == 1:2

ladf = flip(lad)
@test ladf isa LabelledUnitRange
@test label(ladf) == 'x'
@test unlabel(ladf) == 1:2
end

@testset "GradedUnitRangeDual" begin
for a in
[gradedrange([U1(0) => 2, U1(1) => 3]), gradedrange([U1(0) => 2, U1(1) => 3])[1:5]]
Expand Down Expand Up @@ -124,13 +212,21 @@ end
@test blockmergesortperm(a) == [Block(1), Block(2)]
@test blockmergesortperm(ad) == [Block(1), Block(2)]

@test_broken isdual(ad[Block(1)])
@test_broken isdual(ad[Block(1)[1:1]])
@test isdual(ad[Block(1)])
@test isdual(ad[Block(1)[1:1]])
@test ad[Block(1)] isa LabelledUnitRangeDual
@test ad[Block(1)[1:1]] isa LabelledUnitRangeDual
@test label(ad[Block(2)]) == U1(-1)
@test label(ad[Block(2)[1:1]]) == U1(-1)

I = mortar([Block(2)[1:1]])
g = ad[I]
@test length(g) == 1
@test label(first(g)) == U1(-1)
@test_broken isdual(g[Block(1)])
@test isdual(g[Block(1)])

@test isdual(axes(ad[[Block(1)]], 1)) # used in view(::BlockSparseVector, [Block(1)])
@test isdual(axes(ad[mortar([Block(1)[1:1]])], 1)) # used in view(::BlockSparseVector, [Block(1)[1:1]])
end
end

Expand Down
9 changes: 9 additions & 0 deletions NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,12 @@ function Base.iterate(a::LabelledUnitRange, i)
next = convert(eltype(a), labelled(i + step(a), label(a)))
return (next, next)
end

function Base.show(io::IO, ::MIME"text/plain", a::LabelledUnitRange)
println(io, typeof(a))
return print(io, label(a), " => ", unlabel(a))
end

function Base.show(io::IO, a::LabelledUnitRange)
return print(io, nameof(typeof(a)), " ", label(a), " => ", unlabel(a))
end
Loading