Skip to content

Commit e2e6778

Browse files
committed
Fix tests
1 parent 342e37b commit e2e6778

File tree

8 files changed

+35
-12
lines changed

8 files changed

+35
-12
lines changed

NDTensors/src/blocksparse/diagblocksparse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ end
455455
# convert to BlockSparse
456456
function denseblocks(D::Tensor)
457457
nzblocksD = nzblocks(D)
458-
T = BlockSparseTensor(eltype(D), nzblocksD, inds(D))
458+
T = BlockSparseTensor(datatype(D), nzblocksD, inds(D))
459459
for b in nzblocksD
460460
T[b] = D[b]
461461
end

NDTensors/src/dense/dense.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ function copy(D::Dense)
105105
return Dense(copy(expose(data(D))))
106106
end
107107

108+
function Base.copyto!(R::Dense, T::Dense)
109+
copyto!(expose(data(R)), expose(data(T)))
110+
return R
111+
end
112+
108113
function Base.real(T::Type{<:Dense})
109114
return set_datatype(T, similartype(datatype(T), real(eltype(T))))
110115
end

NDTensors/src/dense/densetensor.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ end
6666

6767
convert(::Type{Array}, T::DenseTensor) = reshape(data(storage(T)), dims(inds(T)))
6868

69+
function Base.copyto!(R::DenseTensor, T::DenseTensor)
70+
copyto!(storage(R), storage(T))
71+
return R
72+
end
73+
6974
# Create an Array that is a view of the Dense Tensor
7075
# Useful for using Base Array functions
7176
array(T::DenseTensor) = convert(Array, T)

NDTensors/src/diag/diagtensor.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ Set the entire diagonal of a uniform DiagTensor.
7070
"""
7171
setdiag(T::UniformDiagTensor, val) = tensor(Diag(val), inds(T))
7272

73+
function Base.copyto!(R::DenseTensor, T::DiagTensor)
74+
diagview(R) .= diagview(T)
75+
return R
76+
end
77+
7378
@propagate_inbounds function getindex(
7479
T::DiagTensor{ElT,N}, inds::Vararg{Int,N}
7580
) where {ElT,N}

NDTensors/src/lib/Expose/src/exposed.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,6 @@ end
66

77
expose(object) = Exposed{unwrap_array_type(object),typeof(object)}(object)
88

9-
# This is a corner case that handles the fact that by convention,
10-
# the storage of a uniform diagonaly tensor in NDTensors.jl is a number.
11-
expose(object::Number) = Exposed{typeof(object),typeof(object)}(object)
12-
139
unexpose(E::Exposed) = E.object
1410

1511
## TODO remove TypeParameterAccessors when SetParameters is removed

NDTensors/src/tensor/tensor.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,9 @@ end
204204

205205
copy(T::Tensor) = setstorage(T, copy(storage(T)))
206206

207-
copyto!(R::Tensor, T::Tensor) = (copyto!(storage(R), storage(T)); R)
207+
function copyto!(R::Tensor, T::Tensor)
208+
return error("Not implemented.")
209+
end
208210

209211
complex(T::Tensor) = setstorage(T, complex(storage(T)))
210212

NDTensors/src/tensorstorage/tensorstorage.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,8 @@ Base.real(S::TensorStorage) = setdata(S, real(data(S)))
6262

6363
Base.imag(S::TensorStorage) = setdata(S, imag(data(S)))
6464

65-
function copyto!(S1::TensorStorage, S2::TensorStorage)
66-
copyto!(expose(data(S1)), expose(data(S2)))
67-
return S1
65+
function Base.copyto!(S1::TensorStorage, S2::TensorStorage)
66+
return error("Not implemented.")
6867
end
6968

7069
Random.randn!(S::TensorStorage) = randn!(Random.default_rng(), S)

NDTensors/test/test_diagblocksparse.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,26 @@ using .NDTensorsTestUtils: devices_list
7979
A = BlockSparseTensor{elt}([(1, 1), (2, 2)], [3, 2, 3], [2, 2])
8080
randn!(A)
8181
t = Tensor(DiagBlockSparse(one(elt), blockoffsets(A)), inds(A))
82-
@test_broken dense(contract(A, (1, -2), (t), (3, -2)))
82+
@test dense(contract(A, (1, -2), (t), (3, -2)))
8383
contract(dense(A), (1, -2), dense(t), (3, -2))
84-
@test_broken dense(contract(A, (-2, 1), t, (-2, 3)))
84+
@test dense(contract(A, (-2, 1), t, (-2, 3)))
8585
contract(dense(A), (-2, 1), dense(t), (-2, 3))
86-
@test_broken contract(dev(A), (-1, -2), dev(t), (-1, -2))[]
86+
@test contract(dev(A), (-1, -2), dev(t), (-1, -2))[]
8787
contract(dense(A), (-1, -2), dense(t), (-1, -2))[]
8888
end
8989

9090
@testset "DiagBlockSparse denseblocks" begin
91+
elt = Float64
92+
blockoffsets_a = Dictionary([Block(1, 1), Block(2, 2)], [0, 2])
93+
inds_a = ([2, 2], [2, 2])
94+
a = Tensor(DiagBlockSparse(elt, blockoffsets_a, 4), inds_a)
95+
a[Block(1, 1)][1, 1] = 1
96+
a[Block(1, 1)][2, 2] = 2
97+
a[Block(2, 2)][1, 1] = 3
98+
a[Block(2, 2)][2, 2] = 4
99+
a′ = denseblocks(a)
100+
@test dense(a) == dense(a′)
101+
91102
elt = Float64
92103
blockoffsets_a = Dictionary([Block(1, 1)], [0])
93104
inds_a = ([2], [1, 1])

0 commit comments

Comments
 (0)