Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.4.1"
version = "0.4.2"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
1 change: 1 addition & 0 deletions NDTensors/src/blocksparse/block.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ sethash!(b::Block, h::UInt) = (b.hash[] = h; return b)
#

length(::Block{N}) where {N} = N
length(::Type{<:Block{N}}) where {N} = N

isless(b1::Block, b2::Block) = isless(Tuple(b1), Tuple(b2))

Expand Down
4 changes: 0 additions & 4 deletions NDTensors/src/blocksparse/blockoffsets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ const BlockOffsets{N} = Dictionary{Block{N},Int}

BlockOffset(block::Block{N}, offset::Int) where {N} = BlockOffset{N}(block, offset)

Base.ndims(::Blocks{N}) where {N} = N
Base.ndims(::BlockOffset{N}) where {N} = N
Base.ndims(::BlockOffsets{N}) where {N} = N

blocktype(bofs::BlockOffsets) = keytype(bofs)

nzblock(bof::BlockOffset) = first(bof)
Expand Down
4 changes: 2 additions & 2 deletions NDTensors/src/blocksparse/contract_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ function contract_blockoffsets(
indsR,
labelsR,
)
N1 = ndims(boffs1)
N2 = ndims(boffs2)
N1 = length(blocktype(boffs1))
N2 = length(blocktype(boffs2))
NR = length(labelsR)
ValNR = ValLength(labelsR)
labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR = contract_labels(
Expand Down
4 changes: 2 additions & 2 deletions NDTensors/src/blocksparse/contract_sequential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ function contract_blockoffsets(
indsR,
labelsR,
)
N1 = ndims(boffs1)
N2 = ndims(boffs2)
N1 = length(blocktype(boffs1))
N2 = length(blocktype(boffs2))
NR = length(labelsR)
ValNR = ValLength(labelsR)
labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR = contract_labels(
Expand Down
14 changes: 4 additions & 10 deletions NDTensors/src/blocksparse/diagblocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -448,22 +448,16 @@ function dense(
end

# convert to Dense
function dense(T::TensorT) where {TensorT<:DiagBlockSparseTensor}
R = zeros(dense(TensorT), inds(T))
for i in 1:diaglength(T)
setdiagindex!(R, getdiagindex(T, i), i)
end
return R
function dense(T::DiagBlockSparseTensor)
return dense(denseblocks(T))
end

# convert to BlockSparse
function denseblocks(D::Tensor)
nzblocksD = nzblocks(D)
T = BlockSparseTensor(eltype(D), nzblocksD, inds(D))
T = BlockSparseTensor(datatype(D), nzblocksD, inds(D))
for b in nzblocksD
for n in 1:diaglength(D)
setdiagindex!(T, getdiagindex(D, n), n)
end
T[b] = D[b]
end
return T
end
Expand Down
5 changes: 5 additions & 0 deletions NDTensors/src/dense/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ function copy(D::Dense)
return Dense(copy(expose(data(D))))
end

function Base.copyto!(R::Dense, T::Dense)
copyto!(expose(data(R)), expose(data(T)))
return R
end

function Base.real(T::Type{<:Dense})
return set_datatype(T, similartype(datatype(T), real(eltype(T))))
end
Expand Down
5 changes: 5 additions & 0 deletions NDTensors/src/dense/densetensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ end

convert(::Type{Array}, T::DenseTensor) = reshape(data(storage(T)), dims(inds(T)))

function Base.copyto!(R::DenseTensor, T::DenseTensor)
copyto!(storage(R), storage(T))
return R
end

# Create an Array that is a view of the Dense Tensor
# Useful for using Base Array functions
array(T::DenseTensor) = convert(Array, T)
Expand Down
5 changes: 5 additions & 0 deletions NDTensors/src/diag/diagtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ Set the entire diagonal of a uniform DiagTensor.
"""
setdiag(T::UniformDiagTensor, val) = tensor(Diag(val), inds(T))

function Base.copyto!(R::DenseTensor, T::DiagTensor)
diagview(R) .= diagview(T)
return R
end

@propagate_inbounds function getindex(
T::DiagTensor{ElT,N}, inds::Vararg{Int,N}
) where {ElT,N}
Expand Down
4 changes: 3 additions & 1 deletion NDTensors/src/tensor/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ end

copy(T::Tensor) = setstorage(T, copy(storage(T)))

copyto!(R::Tensor, T::Tensor) = (copyto!(storage(R), storage(T)); R)
function copyto!(R::Tensor, T::Tensor)
return error("Not implemented.")
end

complex(T::Tensor) = setstorage(T, complex(storage(T)))

Expand Down
5 changes: 2 additions & 3 deletions NDTensors/src/tensorstorage/tensorstorage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ Base.real(S::TensorStorage) = setdata(S, real(data(S)))

Base.imag(S::TensorStorage) = setdata(S, imag(data(S)))

function copyto!(S1::TensorStorage, S2::TensorStorage)
copyto!(expose(data(S1)), expose(data(S2)))
return S1
function Base.copyto!(S1::TensorStorage, S2::TensorStorage)
return error("Not implemented.")
end

Random.randn!(S::TensorStorage) = randn!(Random.default_rng(), S)
Expand Down
28 changes: 25 additions & 3 deletions NDTensors/test/test_diagblocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using NDTensors:
blockoffsets,
contract,
dense,
denseblocks,
inds,
nzblocks
using Random: randn!
Expand Down Expand Up @@ -78,11 +79,32 @@ using .NDTensorsTestUtils: devices_list
A = BlockSparseTensor{elt}([(1, 1), (2, 2)], [3, 2, 3], [2, 2])
randn!(A)
t = Tensor(DiagBlockSparse(one(elt), blockoffsets(A)), inds(A))
@test_broken dense(contract(A, (1, -2), (t), (3, -2))) ≈
@test dense(contract(A, (1, -2), (t), (3, -2))) ≈
contract(dense(A), (1, -2), dense(t), (3, -2))
@test_broken dense(contract(A, (-2, 1), t, (-2, 3))) ≈
@test dense(contract(A, (-2, 1), t, (-2, 3))) ≈
contract(dense(A), (-2, 1), dense(t), (-2, 3))
@test_broken contract(dev(A), (-1, -2), dev(t), (-1, -2))[] ≈
@test contract(dev(A), (-1, -2), dev(t), (-1, -2))[] ≈
contract(dense(A), (-1, -2), dense(t), (-1, -2))[]
end

@testset "DiagBlockSparse denseblocks" begin
elt = Float64
blockoffsets_a = Dictionary([Block(1, 1), Block(2, 2)], [0, 2])
inds_a = ([2, 2], [2, 2])
a = Tensor(DiagBlockSparse(elt, blockoffsets_a, 4), inds_a)
a[Block(1, 1)][1, 1] = 1
a[Block(1, 1)][2, 2] = 2
a[Block(2, 2)][1, 1] = 3
a[Block(2, 2)][2, 2] = 4
a′ = denseblocks(a)
@test dense(a) == dense(a′)

elt = Float64
blockoffsets_a = Dictionary([Block(1, 1)], [0])
inds_a = ([2], [1, 1])
a = Tensor(DiagBlockSparse(one(elt), blockoffsets_a), inds_a)
a′ = denseblocks(a)
@test dense(a) == dense(a′)
end

end
Loading