Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.4.1"
version = "0.4.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -40,7 +40,7 @@ MacroTools = "0.5.13"
MapBroadcast = "0.1.5"
SparseArraysBase = "0.5"
SplitApplyCombine = "1.2.3"
TensorAlgebra = "0.2.4"
TensorAlgebra = "0.3"
Test = "1.10"
TypeParameterAccessors = "0.2.0, 0.3"
julia = "1.10"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,34 @@
module BlockSparseArraysTensorAlgebraExt

using BlockArrays: AbstractBlockedUnitRange
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
using TensorAlgebra:
TensorAlgebra,
AbstractBlockPermutation,
BlockedTuple,
FusionStyle,
ReshapeFusion,
fuseaxes

TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion()
struct BlockReshapeFusion <: FusionStyle end

function TensorAlgebra.fusedims(
::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
function TensorAlgebra.FusionStyle(::AbstractBlockSparseArray, ::ReshapeFusion)
return BlockReshapeFusion()

Check warning on line 15 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L14-L15

Added lines #L14 - L15 were not covered by tests
end

function TensorAlgebra.matricize(

Check warning on line 18 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L18

Added line #L18 was not covered by tests
::BlockReshapeFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
)
return blockreshape(a, axes)
a_perm = permutedims(a, Tuple(biperm))
new_axes = fuseaxes(axes(a_perm), biperm)
return blockreshape(a_perm, new_axes)

Check warning on line 23 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L21-L23

Added lines #L21 - L23 were not covered by tests
end

function TensorAlgebra.splitdims(
::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
function TensorAlgebra.unmatricize(

Check warning on line 26 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L26

Added line #L26 was not covered by tests
::BlockReshapeFusion,
m::AbstractMatrix,
blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}},
)
return blockreshape(a, axes)
return blockreshape(m, Tuple(blocked_axes)...)

Check warning on line 31 in ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl#L31

Added line #L31 was not covered by tests
end

end
8 changes: 0 additions & 8 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,12 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Expand All @@ -29,16 +25,12 @@ BlockArrays = "1"
BlockSparseArrays = "0.4"
DiagonalArrays = "0.3"
GPUArraysCore = "0.2"
GradedUnitRanges = "0.2.2"
JLArrays = "0.2"
LabelledNumbers = "0.1"
LinearAlgebra = "1"
Pkg = "1"
Random = "1"
SafeTestsets = "0.1"
SparseArraysBase = "0.5"
Suppressor = "0.2"
SymmetrySectors = "0.1.7"
TensorAlgebra = "0.2.4"
Test = "1"
TestExtras = "0.3"
Expand Down
21 changes: 2 additions & 19 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@ using BlockArrays:
BlockArrays,
Block,
BlockArray,
BlockIndexRange,
BlockRange,
BlockSlice,
BlockVector,
BlockedOneTo,
BlockedUnitRange,
BlockedArray,
BlockedVector,
blockedrange,
Expand All @@ -35,9 +32,8 @@ using BlockSparseArrays:
view!
using GPUArraysCore: @allowscalar
using JLArrays: JLArray, JLMatrix
using LinearAlgebra: Adjoint, Transpose, dot, mul!, norm
using LinearAlgebra: Adjoint, Transpose, dot, norm
using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK, storedlength
using TensorAlgebra: contract
using Test: @test, @test_broken, @test_throws, @testset, @inferred
using TestExtras: @constinferred
using TypeParameterAccessors: TypeParameterAccessors, Position
Expand Down Expand Up @@ -1120,20 +1116,7 @@ arrayts = (Array, JLArray)
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(3, 4)] == a2[Block(1, 2)]
end
@testset "TensorAlgebra" begin
a1 = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
a1[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)]))))
a2 = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
a2[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)]))))
# TODO: Make this work, requires customization of `TensorAlgebra.fusedims` and
# `TensorAlgebra.splitdims` in terms of `BlockSparseArrays.blockreshape`,
# and customization of `TensorAlgebra.:⊗` in terms of `GradedUnitRanges.tensor_product`.
a_dest, dimnames_dest = contract(a1, (1, -1), a2, (-1, 2))
@allowscalar begin
a_dest_dense, dimnames_dest_dense = contract(Array(a1), (1, -1), Array(a2), (-1, 2))
@test a_dest ≈ a_dest_dense
end
end

@testset "blockreshape" begin
a = dev(BlockSparseArray{elt}(undef, ([3, 4], [2, 3])))
a[Block(1, 2)] = dev(randn(elt, size(@view(a[Block(1, 2)]))))
Expand Down
8 changes: 4 additions & 4 deletions test/test_tensoralgebraext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,12 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test a_dest ≈ a_dest_dense

# matrix vector
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
#=
a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockSparseArray
@test a_dest ≈ a_dest_dense
=#

# vector matrix
a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
Expand All @@ -54,12 +52,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test a_dest ≈ a_dest_dense

# vector vector
@test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
#=
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1))
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockSparseArray{elt,0}
@test a_dest ≈ a_dest_dense
=#

# outer product
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4))
Expand Down
Loading