Skip to content

Commit 342e37b

Browse files
committed
[NDTensors] Fix denseblocks
1 parent 6639166 commit 342e37b

File tree

7 files changed

+22
-16
lines changed

7 files changed

+22
-16
lines changed

NDTensors/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NDTensors"
22
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
33
authors = ["Matthew Fishman <[email protected]>"]
4-
version = "0.4.1"
4+
version = "0.4.2"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

NDTensors/src/blocksparse/block.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ sethash!(b::Block, h::UInt) = (b.hash[] = h; return b)
7575
#
7676

7777
length(::Block{N}) where {N} = N
78+
length(::Type{<:Block{N}}) where {N} = N
7879

7980
isless(b1::Block, b2::Block) = isless(Tuple(b1), Tuple(b2))
8081

NDTensors/src/blocksparse/blockoffsets.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@ const BlockOffsets{N} = Dictionary{Block{N},Int}
1212

1313
BlockOffset(block::Block{N}, offset::Int) where {N} = BlockOffset{N}(block, offset)
1414

15-
Base.ndims(::Blocks{N}) where {N} = N
16-
Base.ndims(::BlockOffset{N}) where {N} = N
17-
Base.ndims(::BlockOffsets{N}) where {N} = N
18-
1915
blocktype(bofs::BlockOffsets) = keytype(bofs)
2016

2117
nzblock(bof::BlockOffset) = first(bof)

NDTensors/src/blocksparse/contract_sequential.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ function contract_blockoffsets(
99
indsR,
1010
labelsR,
1111
)
12-
N1 = ndims(boffs1)
13-
N2 = ndims(boffs2)
12+
N1 = length(keytype(boffs1))
13+
N2 = length(keytype(boffs2))
1414
NR = length(labelsR)
1515
ValNR = ValLength(labelsR)
1616
labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR = contract_labels(

NDTensors/src/blocksparse/diagblocksparse.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -448,22 +448,16 @@ function dense(
448448
end
449449

450450
# convert to Dense
451-
function dense(T::TensorT) where {TensorT<:DiagBlockSparseTensor}
452-
R = zeros(dense(TensorT), inds(T))
453-
for i in 1:diaglength(T)
454-
setdiagindex!(R, getdiagindex(T, i), i)
455-
end
456-
return R
451+
function dense(T::DiagBlockSparseTensor)
452+
return dense(denseblocks(T))
457453
end
458454

459455
# convert to BlockSparse
460456
function denseblocks(D::Tensor)
461457
nzblocksD = nzblocks(D)
462458
T = BlockSparseTensor(eltype(D), nzblocksD, inds(D))
463459
for b in nzblocksD
464-
for n in 1:diaglength(D)
465-
setdiagindex!(T, getdiagindex(D, n), n)
466-
end
460+
T[b] = D[b]
467461
end
468462
return T
469463
end

NDTensors/src/lib/Expose/src/exposed.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ end
66

77
expose(object) = Exposed{unwrap_array_type(object),typeof(object)}(object)
88

9+
# This is a corner case that handles the fact that by convention,
10+
# the storage of a uniform diagonaly tensor in NDTensors.jl is a number.
11+
expose(object::Number) = Exposed{typeof(object),typeof(object)}(object)
12+
913
unexpose(E::Exposed) = E.object
1014

1115
## TODO remove TypeParameterAccessors when SetParameters is removed

NDTensors/test/test_diagblocksparse.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using NDTensors:
1111
blockoffsets,
1212
contract,
1313
dense,
14+
denseblocks,
1415
inds,
1516
nzblocks
1617
using Random: randn!
@@ -85,4 +86,14 @@ using .NDTensorsTestUtils: devices_list
8586
@test_broken contract(dev(A), (-1, -2), dev(t), (-1, -2))[]
8687
contract(dense(A), (-1, -2), dense(t), (-1, -2))[]
8788
end
89+
90+
@testset "DiagBlockSparse denseblocks" begin
91+
elt = Float64
92+
blockoffsets_a = Dictionary([Block(1, 1)], [0])
93+
inds_a = ([2], [1, 1])
94+
a = Tensor(DiagBlockSparse(one(elt), blockoffsets_a), inds_a)
95+
a′ = denseblocks(a)
96+
@test dense(a) == dense(a′)
97+
end
98+
8899
end

0 commit comments

Comments
 (0)