Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
version = "0.10.12"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.10.13"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -22,11 +22,9 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

[weakdeps]
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"

[extensions]
BlockSparseArraysTensorAlgebraExt = "TensorAlgebra"
BlockSparseArraysTensorProductsExt = "TensorProducts"

[compat]
Adapt = "4.1.1"
Expand All @@ -44,8 +42,7 @@ MapBroadcast = "0.1.5"
MatrixAlgebraKit = "0.6"
SparseArraysBase = "0.7.1"
SplitApplyCombine = "1.2.3"
TensorAlgebra = "0.5"
TensorProducts = "0.1.7"
TensorAlgebra = "0.6"
Test = "1.10"
TypeParameterAccessors = "0.4.1"
julia = "1.10"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
module BlockSparseArraysTensorAlgebraExt

using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, fuseaxes

struct BlockReshapeFusion <: FusionStyle end
using BlockArrays: Block, blocklength, blocks, eachblockaxes1
using BlockSparseArrays: AbstractBlockSparseArray, AbstractBlockSparseMatrix,
BlockUnitRange, blockrange, blocksparse
using SparseArraysBase: eachstoredindex
using TensorAlgebra: TensorAlgebra, BlockReshapeFusion, BlockedTuple, matricize,
matricize_axes, tensor_product_axis, unmatricize

function TensorAlgebra.FusionStyle(::Type{<:AbstractBlockSparseArray})
return BlockReshapeFusion()
function TensorAlgebra.tensor_product_axis(
::BlockReshapeFusion, r1::BlockUnitRange, r2::BlockUnitRange
)
isone(first(r1)) || isone(first(r2)) ||
throw(ArgumentError("Only one-based axes are supported"))
blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2))
blockaxs = vec(splat(tensor_product_axis).(blockaxpairs))
return blockrange(blockaxs)
end

using BlockArrays: Block, blocklength, blocks
using BlockSparseArrays: blocksparse
using SparseArraysBase: eachstoredindex
using TensorAlgebra: TensorAlgebra, matricize, unmatricize
function TensorAlgebra.matricize(
::BlockReshapeFusion, a::AbstractArray, length1::Val, length2::Val
style::BlockReshapeFusion, a::AbstractBlockSparseArray, length_codomain::Val
)
ax = fuseaxes(axes(a), length1, length2)
reshaped_blocks_a = reshape(blocks(a), map(blocklength, ax))
ax = matricize_axes(style, a, length_codomain)
reshaped_blocks_a = reshape(blocks(a), blocklength.(ax))
key(I) = Block(Tuple(I))
value(I) = matricize(reshaped_blocks_a[I], length1, length2)
value(I) = matricize(reshaped_blocks_a[I], length_codomain)
Is = eachstoredindex(reshaped_blocks_a)
bs = if isempty(Is)
# Catch empty case and make sure the type is constrained properly.
Expand All @@ -35,16 +39,16 @@ function TensorAlgebra.matricize(
return blocksparse(bs, ax)
end

using BlockArrays: blocklengths
function TensorAlgebra.unmatricize(
::BlockReshapeFusion,
m::AbstractMatrix,
m::AbstractBlockSparseMatrix,
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
domain_axes::Tuple{Vararg{AbstractUnitRange}},
)
ax = (codomain_axes..., domain_axes...)
reshaped_blocks_m = reshape(blocks(m), map(blocklength, ax))
function f(I)
reshaped_blocks_m = reshape(blocks(m), blocklength.(ax))
key(I) = Block(Tuple(I))
function value(I)
block_axes_I = BlockedTuple(
map(ntuple(identity, length(ax))) do i
return Base.axes1(ax[i][Block(I[i])])
Expand All @@ -53,7 +57,7 @@ function TensorAlgebra.unmatricize(
)
return unmatricize(reshaped_blocks_m[I], block_axes_I)
end
bs = Dict(Block(Tuple(I)) => f(I) for I in eachstoredindex(reshaped_blocks_m))
bs = Dict(key(I) => value(I) for I in eachstoredindex(reshaped_blocks_m))
return blocksparse(bs, ax)
end

Expand Down

This file was deleted.

2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ SafeTestsets = "0.1"
SparseArraysBase = "0.7"
StableRNGs = "1"
Suppressor = "0.2"
TensorAlgebra = "0.5"
TensorAlgebra = "0.6"
Test = "1"
TestExtras = "0.3"
TypeParameterAccessors = "0.4"
Loading