Skip to content

Commit 1f7ce5c

Browse files
committed
More test via CPU
1 parent b673271 commit 1f7ce5c

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

src/tensors/tensor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ for randf in (:rand, :randn, :randexp, :randisometry)
403403
end
404404
end
405405

406-
# Collecting arbitrary TensorMaps
406+
# Collecting arbitrary TensorMaps
407407
#-----------------------------
408408
Base.collect(t::TensorMap) = convert(TensorMapWithStorage{scalartype(t), similarstoragetype(scalartype(t))}, t)
409409

test/cuda/tensors.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -289,20 +289,14 @@ for V in spacelist
289289
for p in permutations(1:5)
290290
p1 = ntuple(n -> p[n], k)
291291
p2 = ntuple(n -> p[k + n], 5 - k)
292-
t2 = CUDA.@allowscalar permute(t, (p1, p2))
293-
a2 = convert(Array, collect(t2))
294-
@test a2 permutedims(a, (p1..., p2...))
295-
CUDA.@allowscalar begin
296-
@test convert(Array, collect(transpose(t2)))
297-
permutedims(a2, (5, 4, 3, 2, 1))
298-
end
292+
dt2 = CUDA.@allowscalar permute(t, (p1, p2))
293+
ht2 = permute(collect(t), (p1, p2))
294+
@test ht2 == collect(dt2)
299295
end
300296

301-
t3 = CUDA.@allowscalar repartition(t, k)
302-
a3 = convert(Array, collect(t3))
303-
@test a3 permutedims(
304-
a, (ntuple(identity, k)..., reverse(ntuple(i -> i + k, 5 - k))...)
305-
)
297+
dt3 = CUDA.@allowscalar repartition(t, k)
298+
ht3 = repartition(collect(t), k)
299+
@test ht3 == collect(dt3)
306300
end
307301
end
308302
end

0 commit comments

Comments
 (0)