Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "GradedArrays"
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
version = "0.6.19"
version = "0.6.20"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand Down
7 changes: 3 additions & 4 deletions src/sectorarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,9 @@ function Base.show(
io::IO, g::SectorUnitRange{I, RB, R}
) where {I <: SectorRange, RB <: AbstractUnitRange{Int}, R <: AbstractUnitRange{Int}}
a, b = kroneckerfactors(g)
if b isa Base.OneTo
print(io, "sectorrange(", a, ", ", unproduct(g), ")")
else
print(io, "sectorrange(", a, " => ", b, ", ", unproduct(g), ")")
print(io, "sectorrange(", a, ", ", b, ")")
if !isone(first(g))
print(io, " .+ ", first(g) - 1)
end
return nothing
end
Expand Down
140 changes: 122 additions & 18 deletions src/tensoralgebra.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using BlockArrays: blocks, eachblockaxes1
using BlockArrays: BlockIndexRange, blocks, eachblockaxes1
using BlockSparseArrays: BlockSparseArray, blockrange, blockreshape
using GradedArrays: GradedArray, GradedUnitRange, SectorRange, flip, invblockperm,
sectormergesortperm, sectorsortperm, trivial, unmerged_tensor_product, ×
using GradedArrays: GradedArray, GradedUnitRange, SectorRange, flip, gradedrange,
invblockperm, sectormergesortperm, sectors, sectorsortperm, trivial,
unmerged_tensor_product, ×
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTuple, FusionStyle,
ReshapeFusion, matricize, matricize_axes, tensor_product_axis, trivialbiperm,
tuplemortar, unmatricize
Expand Down Expand Up @@ -118,27 +119,130 @@ function TensorAlgebra.unmatricize(
return a
end

# First, fuse axes to get `sectormergesortperm`.
# Then unpermute the blocks.
fused_axes = matricize_axes(BlockReshapeFusion(), m, codomain_axes, domain_axes)

blockperms = sectorsortperm.(fused_axes)
sorted_axes = map((r, I) -> only(axes(r[I])), fused_axes, blockperms)

# TODO: This is doing extra copies of the blocks,
# use `@view a[axes_prod...]` instead.
# That will require implementing some reindexing logic
# for this combination of slicing.
m_unblocked = m[sorted_axes...]
m_blockpermed = m_unblocked[invblockperm.(blockperms)...]
return unmatricize(FusionStyle(BlockSparseArray), m_blockpermed, blocked_axes)
J = map(invblockmergeperm, fused_axes, blockperms, axes(m))
return unmatricize(FusionStyle(BlockSparseArray), m[J...], blocked_axes)
end

# Sort the blocks by sector and then merge the common sectors.
function sectormergesort(a::AbstractArray)
# TODO: fix this, no clue why broken and no clue how to fix
return a

I = sectormergesortperm.(axes(a))
return a[I...]
end

# Returns a Vector{BlockIndexRange{1}} mapping each block of fine_ax (in original order)
# to its position (block + subrange) within the merged axis merged_ax, given the block
# permutation blockperm used to sort and merge fine_ax into merged_ax.
# Requires that blocks of fine_ax subdivide blocks of merged_ax.
function invblockmergeperm(fine_ax, blockperm, merged_ax)
n = length(blockperm)
bir_type = Base.promote_op(getindex, Block{1, Int}, UnitRange{Int})
J = Vector{bir_type}(undef, n)
j = 1
offset = 0
for k′ in 1:n
k = Int(blockperm[k′])
size_k = length(fine_ax[Block(k)])
merged_block_size = length(merged_ax[Block(j)])
offset + size_k ≤ merged_block_size ||
throw(ArgumentError("fine_ax blocks do not subdivide merged_ax blocks"))
J[k] = Block(j)[(offset + 1):(offset + size_k)]
offset += size_k
if offset == merged_block_size
j += 1
offset = 0
end
end
return J
end

using BlockArrays: AbstractBlockVector, Block

function checkindices(
a::GradedArray{<:Any, N}, I::NTuple{N, AbstractVector{<:BlockIndexRange{1}}}
) where {N}
for d in 1:N
nblocks_d = length(axes(a, d))
for bir in I[d]
Int(bir.block) ≤ nblocks_d ||
throw(BlockBoundsError(a, ntuple(i -> i == d ? bir : I[i][1], Val(N))))
end
end
return nothing
end

# Splitting: each I[d][k] = Block(b)[r] means dest block k comes from source block b
# at subrange r. This is the inverse of the merging getindex below.
function Base.getindex(
a::GradedArray{<:Any, N}, I::Vararg{AbstractVector{<:BlockIndexRange{1}}, N}
) where {N}
checkindices(a, I)
ax_dest = ntuple(d -> only(axes(axes(a, d)[I[d]])), Val(N))
a_dest = similar(a, ax_dest)
# Map source block b → list of (dest BlockIndexRange, src subrange).
# Stored blocks of a not referenced by I are skipped (partial block selection).
src_to_dests = ntuple(Val(N)) do d
key_type = Block{1, Int}
dest_bir_type = Base.promote_op(getindex, key_type, Base.OneTo{Int})
val_type = Tuple{dest_bir_type, UnitRange{Int}}
dict = Dict{key_type, Vector{val_type}}()
for k in eachindex(I[d])
bir = I[d][k]
b = Block(Int(bir.block))
r = only(bir.indices)
push!(get!(dict, b, val_type[]), (Block(k)[Base.axes1(r)], r))
end
return dict
end
for bI_src in eachblockstoredindex(a)
src_tuple = Tuple(bI_src)
all(d -> haskey(src_to_dests[d], src_tuple[d]), 1:N) || continue
dest_refs = ntuple(d -> src_to_dests[d][src_tuple[d]], Val(N))
for combo in Iterators.product(dest_refs...)
src_r = ntuple(d -> combo[d][2], Val(N))
src_data = @view(a[bI_src][src_r...])
iszero(src_data) && continue
dest_b = Block(ntuple(d -> only(Tuple(combo[d][1].block)), Val(N)))
a_dest_b = @view!(a_dest[dest_b])
dest_r = ntuple(d -> only(combo[d][1].indices), Val(N))
copyto!(@view(a_dest_b[dest_r...]), src_data)
end
end
return a_dest
end

# Merging: each I[d] groups source blocks into destination blocks.
function Base.getindex(
a::GradedArray{<:Any, N}, I::Vararg{AbstractBlockVector{<:Block{1}}, N}
) where {N}
ax_dest = ntuple(d -> Base.axes1(axes(a, d)[I[d]]), Val(N))
a_dest = similar(a, ax_dest)
ax = axes(a)
# Map source Block -> BlockIndexRange encoding dest block + subrange within it
src_to_dest = ntuple(Val(N)) do d
key_type = eltype(I[d])
range_type = UnitRange{Int}
val_type = Base.promote_op(getindex, key_type, range_type)
dict = Dict{key_type, val_type}()
for j in eachindex(blocks(I[d]))
sub_blocks = I[d][Block(j)]
start = 1
for b in sub_blocks
r = Base.OneTo(length(ax[d][b])) .+ (start - 1)
dict[b] = Block(j)[r]
start += length(r)
end
end
return dict
end
for bI_src in eachblockstoredindex(a)
src_tuple = Tuple(bI_src)
dest_info = ntuple(d -> src_to_dest[d][src_tuple[d]], Val(N))
dest_b = Block(map(di -> only(Tuple(di.block)), dest_info))
a_dest_b = @view!(a_dest[dest_b])
dest_r = map(di -> only(di.indices), dest_info)
copyto!(@view(a_dest_b[dest_r...]), a[bI_src])
end
return a_dest
end
2 changes: 1 addition & 1 deletion test/test_gradedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
I = [Block(1)[1:1]]
@test_broken size(b[I, :]) == (1, 4)
@test_broken size(b[:, I]) == (4, 1)
@test_broken size(b[I, I]) == (1, 1)
@test size(b[I, I]) == (1, 1)
end
end
@testset "Matrix multiplication" begin
Expand Down
12 changes: 6 additions & 6 deletions test/test_show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ end
@test sprint(show, g1) == "GradedUnitRange[$x => 2, $y => 3, $z => 2]"
@test sprint(show, MIME("text/plain"), g1) ==
"GradedUnitRange{$U1}\n" *
"sectorrange($x, 1:2)\n" *
"sectorrange($y, 3:5)\n" *
"sectorrange($z, 6:7)"
"sectorrange($x, Base.OneTo(2))\n" *
"sectorrange($y, Base.OneTo(3)) .+ 2\n" *
"sectorrange($z, Base.OneTo(2)) .+ 5"

g1d = dual(g1)
@test sprint(show, g1d) == "GradedUnitRange[$x' => 2, $y' => 3, $z' => 2]"
@test sprint(show, MIME("text/plain"), g1d) ==
"GradedUnitRange{$U1}\n" *
"sectorrange($x', 1:2)\n" *
"sectorrange($y', 3:5)\n" *
"sectorrange($z', 6:7)"
"sectorrange($x', Base.OneTo(2))\n" *
"sectorrange($y', Base.OneTo(3)) .+ 2\n" *
"sectorrange($z', Base.OneTo(2)) .+ 5"
end

@testset "show GradedArray" begin
Expand Down
8 changes: 4 additions & 4 deletions test/test_tensoralgebraext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using GradedArrays: GradedArray, GradedMatrix, SU2, SectorDelta, U1, dual, flip,
using Random: randn!
using TensorAlgebra:
FusionStyle, contract, matricize, tensor_product_axis, trivial_axis, unmatricize
using Test: @test, @testset
using Test: @test, @test_broken, @testset

function randn_blockdiagonal(elt::Type, axes::Tuple)
a = BlockSparseArray{elt}(undef, axes)
Expand Down Expand Up @@ -65,10 +65,10 @@ end
@test unmatricize(m, (U1(1), U1(1)), (U1(-2), U1(-1))) isa SectorDelta
end

broken = true
const contract_broken = true

const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
broken || @testset "`contract` `GradedArray` (eltype=$elt)" for elt in elts
@testset "`contract` `GradedArray` (eltype=$elt)" for elt in elts
@testset "matricize" begin
d1 = gradedrange([U1(0) => 1, U1(1) => 1])
d2 = gradedrange([U1(0) => 1, U1(1) => 1])
Expand Down Expand Up @@ -115,7 +115,7 @@ broken || @testset "`contract` `GradedArray` (eltype=$elt)" for elt in elts
@test a == unmatricize(m, (), (d1, d2, dual(d1), dual(d2)))
end

@testset "contract with U(1)" begin
contract_broken || @testset "contract with U(1)" begin
d = gradedrange([U1(0) => 2, U1(1) => 3])
a1 = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
a2 = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
Expand Down
Loading