Skip to content

Commit b673271

Browse files
committed
move collect in
1 parent 5c8121a commit b673271

File tree

3 files changed

+18
-18
lines changed

3 files changed

+18
-18
lines changed

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@ function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂,
77
return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t))
88
end
99

10-
function Base.collect(t::CuTensorMap{T}) where {T}
11-
return convert(TensorKit.TensorMapWithStorage{T, Vector{T}}, t)
12-
end
13-
1410
# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy
1511
function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}}
1612
h_t = TensorKit.TensorMapWithStorage{T, Vector{T}}(undef, V)

src/tensors/tensor.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,10 @@ for randf in (:rand, :randn, :randexp, :randisometry)
403403
end
404404
end
405405

406+
# Collecting arbitrary TensorMaps
407+
#-----------------------------
408+
Base.collect(t::TensorMap) = convert(TensorMapWithStorage{scalartype(t), similarstoragetype(scalartype(t))}, t)
409+
406410
# Efficient copy constructors
407411
#-----------------------------
408412
Base.copy(t::TensorMap) = typeof(t)(copy(t.data), t.space)

test/cuda/tensors.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,14 @@ for V in spacelist
152152
@test dot(t2, t) conj(dot(t2', t'))
153153
@test dot(t2, t) dot(t', t2')
154154

155-
i1 = @constinferred(isomorphism(T, V1 V2, V2 V1))
156-
i2 = @constinferred(isomorphism(CuVector{T}, V2 V1, V1 V2))
157-
CUDA.@allowscalar begin
158-
@test i1 * i2 == @constinferred(id(CuVector{T, CUDA.DeviceMemory}, V1 V2))
159-
@test i2 * i1 == @constinferred(id(CuVector{T, CUDA.DeviceMemory}, V2 V1))
160-
w = @constinferred(isometry(CuVector{T, CUDA.DeviceMemory}, V1 (oneunit(V1) oneunit(V1)), V1))
161-
@test dim(w) == 2 * dim(V1 V1)
162-
@test w' * w == id(CuVector{T, CUDA.DeviceMemory}, V1)
163-
@test w * w' == (w * w')^2
164-
end
155+
i1 = @constinferred(isomorphism(CuVector{T, CUDA.DeviceMemory}, V1 V2, V2 V1))
156+
i2 = @constinferred(isomorphism(CuVector{T, CUDA.DeviceMemory}, V2 V1, V1 V2))
157+
@test i1 * i2 == @constinferred(id(CuVector{T, CUDA.DeviceMemory}, V1 V2))
158+
@test i2 * i1 == @constinferred(id(CuVector{T, CUDA.DeviceMemory}, V2 V1))
159+
w = @constinferred(isometry(CuVector{T, CUDA.DeviceMemory}, V1 (oneunit(V1) oneunit(V1)), V1))
160+
@test dim(w) == 2 * dim(V1 V1)
161+
@test w' * w == id(CuVector{T, CUDA.DeviceMemory}, V1)
162+
@test w * w' == (w * w')^2
165163
end
166164
end
167165
@timedtestset "Trivial space insertion and removal" begin
@@ -238,11 +236,11 @@ for V in spacelist
238236
@timedtestset "Tensor conversion" begin # TODO adjoint conversion methods don't work yet
239237
W = V1 V2
240238
t = @constinferred CUDA.randn(W W)
241-
@test typeof(convert(TensorMap, t')) == typeof(t)
239+
#@test typeof(convert(TensorMap, t')) == typeof(t) # TODO Adjoint not supported yet
242240
tc = complex(t)
243241
@test convert(typeof(tc), t) == tc
244242
@test typeof(convert(typeof(tc), t)) == typeof(tc)
245-
@test typeof(convert(typeof(tc), t')) == typeof(tc)
243+
# @test typeof(convert(typeof(tc), t')) == typeof(tc) # TODO Adjoint not supported yet
246244
@test Base.promote_typeof(t, tc) == typeof(tc)
247245
@test Base.promote_typeof(tc, t) == typeof(tc + t)
248246
end
@@ -294,8 +292,10 @@ for V in spacelist
294292
t2 = CUDA.@allowscalar permute(t, (p1, p2))
295293
a2 = convert(Array, collect(t2))
296294
@test a2 permutedims(a, (p1..., p2...))
297-
@test convert(Array, collect(transpose(t2)))
298-
permutedims(a2, (5, 4, 3, 2, 1))
295+
CUDA.@allowscalar begin
296+
@test convert(Array, collect(transpose(t2)))
297+
permutedims(a2, (5, 4, 3, 2, 1))
298+
end
299299
end
300300

301301
t3 = CUDA.@allowscalar repartition(t, k)

0 commit comments

Comments
 (0)