Skip to content

Commit 1f0dea3

Browse files
committed
Try fixing tests
1 parent ba9c22f commit 1f0dea3

File tree

2 files changed

+128
-2
lines changed

2 files changed

+128
-2
lines changed

README.md

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
88
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
99

10+
A block sparse array type in Julia based on the [`BlockArrays.jl`](https://github.com/JuliaArrays/BlockArrays.jl) interface.
11+
1012
## Installation instructions
1113

1214
This package resides in the `ITensor/ITensorRegistry` local registry.
@@ -32,10 +34,131 @@ julia> Pkg.add("BlockSparseArrays")
3234
## Examples
3335

3436
````julia
35-
using BlockSparseArrays: BlockSparseArrays
37+
using BlockArrays: BlockArrays, BlockedVector, Block, blockedrange
38+
using BlockSparseArrays: BlockSparseArray, block_stored_length
39+
using Test: @test, @test_broken
40+
41+
function main()
42+
# Block dimensions
43+
i1 = [2, 3]
44+
i2 = [2, 3]
45+
46+
i_axes = (blockedrange(i1), blockedrange(i2))
47+
48+
function block_size(axes, block)
49+
return length.(getindex.(axes, Block.(block.n)))
50+
end
51+
52+
# Data
53+
nz_blocks = Block.([(1, 1), (2, 2)])
54+
nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks]
55+
nz_block_lengths = prod.(nz_block_sizes)
56+
57+
# Blocks with contiguous underlying data
58+
d_data = BlockedVector(randn(sum(nz_block_lengths)), nz_block_lengths)
59+
d_blocks = [
60+
reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for
61+
i in 1:length(nz_blocks)
62+
]
63+
b = BlockSparseArray(nz_blocks, d_blocks, i_axes)
64+
65+
@test block_stored_length(b) == 2
66+
67+
# Blocks with discontiguous underlying data
68+
d_blocks = randn.(nz_block_sizes)
69+
b = BlockSparseArray(nz_blocks, d_blocks, i_axes)
70+
71+
@test block_stored_length(b) == 2
72+
73+
# Access a block
74+
@test b[Block(1, 1)] == d_blocks[1]
75+
76+
# Access a zero block, returns a zero matrix
77+
@test b[Block(1, 2)] == zeros(2, 3)
78+
79+
# Set a zero block
80+
a₁₂ = randn(2, 3)
81+
b[Block(1, 2)] = a₁₂
82+
@test b[Block(1, 2)] == a₁₂
83+
84+
# Matrix multiplication
85+
# TODO: Fix this, broken.
86+
@test_broken b * b Array(b) * Array(b)
87+
88+
permuted_b = permutedims(b, (2, 1))
89+
@test permuted_b isa BlockSparseArray
90+
@test permuted_b == permutedims(Array(b), (2, 1))
91+
92+
@test b + b Array(b) + Array(b)
93+
@test b + b isa BlockSparseArray
94+
# TODO: Fix this, broken.
95+
@test_broken block_stored_length(b + b) == 2
96+
97+
scaled_b = 2b
98+
@test scaled_b 2Array(b)
99+
@test scaled_b isa BlockSparseArray
100+
101+
# TODO: Fix this, broken.
102+
@test_broken reshape(b, ([4, 6, 6, 9],)) isa BlockSparseArray{<:Any,1}
103+
104+
return nothing
105+
end
106+
107+
main()
36108
````
37109

38-
Examples go here.
110+
# BlockSparseArrays.jl and BlockArrays.jl interface
111+
112+
````julia
113+
using BlockArrays: BlockArrays, Block
114+
using BlockSparseArrays: BlockSparseArray
115+
116+
i1 = [2, 3]
117+
i2 = [2, 3]
118+
B = BlockSparseArray{Float64}(i1, i2)
119+
B[Block(1, 1)] = randn(2, 2)
120+
B[Block(2, 2)] = randn(3, 3)
121+
122+
# Minimal interface
123+
124+
# Specifies the block structure
125+
@show collect.(BlockArrays.blockaxes(axes(B, 1)))
126+
127+
# Index range of a block
128+
@show axes(B, 1)[Block(1)]
129+
130+
# Last index of each block
131+
@show BlockArrays.blocklasts(axes(B, 1))
132+
133+
# Find the block containing the index
134+
@show BlockArrays.findblock(axes(B, 1), 3)
135+
136+
# Retrieve a block
137+
@show B[Block(1, 1)]
138+
@show BlockArrays.viewblock(B, Block(1, 1))
139+
140+
# Check block bounds
141+
@show BlockArrays.blockcheckbounds(B, 2, 2)
142+
@show BlockArrays.blockcheckbounds(B, Block(2, 2))
143+
144+
# Derived interface
145+
146+
# Specifies the block structure
147+
@show collect(Iterators.product(BlockArrays.blockaxes(B)...))
148+
149+
# Iterate over block views
150+
@show sum.(BlockArrays.eachblock(B))
151+
152+
# Reshape into 1-d
153+
# TODO: Fix this, broken.
154+
# @show BlockArrays.blockvec(B)[Block(1)]
155+
156+
# Array-of-array view
157+
@show BlockArrays.blocks(B)[1, 1] == B[Block(1, 1)]
158+
159+
# Access an index within a block
160+
@show B[Block(1, 1)[1, 1]] == B[1, 1]
161+
````
39162

40163
---
41164

test/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
23
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
34
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
45
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
@@ -11,7 +12,9 @@ NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
1112
NestedPermutedDimsArrays = "2c2a8ec4-3cfc-4276-aa3e-1307b4294e58"
1213
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1314
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1416
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
17+
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1518
SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e"
1619
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1720
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 commit comments

Comments
 (0)