We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b673271 commit 1f7ce5cCopy full SHA for 1f7ce5c
src/tensors/tensor.jl
@@ -403,7 +403,7 @@ for randf in (:rand, :randn, :randexp, :randisometry)
403
end
404
405
406
-# Collecting arbitrary TensorMaps
+# Collecting arbitrary TensorMaps
407
#-----------------------------
408
Base.collect(t::TensorMap) = convert(TensorMapWithStorage{scalartype(t), similarstoragetype(scalartype(t))}, t)
409
test/cuda/tensors.jl
@@ -289,20 +289,14 @@ for V in spacelist
289
for p in permutations(1:5)
290
p1 = ntuple(n -> p[n], k)
291
p2 = ntuple(n -> p[k + n], 5 - k)
292
- t2 = CUDA.@allowscalar permute(t, (p1, p2))
293
- a2 = convert(Array, collect(t2))
294
- @test a2 ≈ permutedims(a, (p1..., p2...))
295
- CUDA.@allowscalar begin
296
- @test convert(Array, collect(transpose(t2))) ≈
297
- permutedims(a2, (5, 4, 3, 2, 1))
298
- end
+ dt2 = CUDA.@allowscalar permute(t, (p1, p2))
+ ht2 = permute(collect(t), (p1, p2))
+ @test ht2 == collect(dt2)
299
300
301
- t3 = CUDA.@allowscalar repartition(t, k)
302
- a3 = convert(Array, collect(t3))
303
- @test a3 ≈ permutedims(
304
- a, (ntuple(identity, k)..., reverse(ntuple(i -> i + k, 5 - k))...)
305
- )
+ dt3 = CUDA.@allowscalar repartition(t, k)
+ ht3 = repartition(collect(t), k)
+ @test ht3 == collect(dt3)
306
307
308
0 commit comments