Skip to content

Commit e165c72

Browse files
committed
Runic again
1 parent a2a3e9a commit e165c72

File tree

3 files changed

+38
-28
lines changed

3 files changed

+38
-28
lines changed

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,10 @@ function CuTensorMap(
7575
) where {S}
7676
return CuTensorMap(data, codom dom)
7777
end
78-
function CuTensorMap(data::DenseVector{T}, V::TensorMapSpace{S, N₁, N₂}) where {T,S,N₁,N₂}
78+
function CuTensorMap(data::DenseVector{T}, V::TensorMapSpace{S, N₁, N₂}) where {T, S, N₁, N₂}
7979
return CuTensorMap{T, S, N₁, N₂}(data, V)
8080
end
81-
function CuTensorMap(data::CuArray{T}, V::TensorMapSpace{S, N₁, N₂}) where {T,S,N₁,N₂}
81+
function CuTensorMap(data::CuArray{T}, V::TensorMapSpace{S, N₁, N₂}) where {T, S, N₁, N₂}
8282
return CuTensorMap{T, S, N₁, N₂}(vec(data), V)
8383
end
8484

@@ -168,44 +168,54 @@ for randfun in (:rand, :randn, :randisometry)
168168
randfun! = Symbol(randfun, :!)
169169
@eval begin
170170
# converting `codomain` and `domain` into `HomSpace`
171-
function $randfun(::Type{A}, codomain::TensorSpace{S},
172-
domain::TensorSpace{S}) where {A <: CuArray, S<:IndexSpace}
171+
function $randfun(
172+
::Type{A}, codomain::TensorSpace{S},
173+
domain::TensorSpace{S}
174+
) where {A <: CuArray, S <: IndexSpace}
173175
return $randfun(A, codomain domain)
174176
end
175-
function $randfun(::Type{T}, ::Type{A}, codomain::TensorSpace{S},
176-
domain::TensorSpace{S}) where {T,S<:IndexSpace, A<:CuArray{T}}
177+
function $randfun(
178+
::Type{T}, ::Type{A}, codomain::TensorSpace{S},
179+
domain::TensorSpace{S}
180+
) where {T, S <: IndexSpace, A <: CuArray{T}}
177181
return $randfun(T, A, codomain domain)
178182
end
179-
function $randfun(rng::Random.AbstractRNG, ::Type{T}, ::Type{A},
180-
codomain::TensorSpace{S},
181-
domain::TensorSpace{S}) where {T,S<:IndexSpace, A<:CuArray{T}}
183+
function $randfun(
184+
rng::Random.AbstractRNG, ::Type{T}, ::Type{A},
185+
codomain::TensorSpace{S},
186+
domain::TensorSpace{S}
187+
) where {T, S <: IndexSpace, A <: CuArray{T}}
182188
return $randfun(rng, T, A, codomain domain)
183189
end
184190

185191
# accepting single `TensorSpace`
186-
$randfun(::Type{A}, codomain::TensorSpace) where {A<:CuArray} = $randfun(A, codomain one(codomain))
187-
function $randfun(::Type{T}, ::Type{A}, codomain::TensorSpace) where {T, A<:CuArray{T}}
192+
$randfun(::Type{A}, codomain::TensorSpace) where {A <: CuArray} = $randfun(A, codomain one(codomain))
193+
function $randfun(::Type{T}, ::Type{A}, codomain::TensorSpace) where {T, A <: CuArray{T}}
188194
return $randfun(T, A, codomain one(codomain))
189195
end
190-
function $randfun(rng::Random.AbstractRNG, ::Type{T},
191-
::Type{A}, codomain::TensorSpace) where {T, A<:CuArray{T}}
196+
function $randfun(
197+
rng::Random.AbstractRNG, ::Type{T},
198+
::Type{A}, codomain::TensorSpace
199+
) where {T, A <: CuArray{T}}
192200
return $randfun(rng, T, A, codomain one(domain))
193201
end
194202

195203
# filling in default eltype
196-
$randfun(::Type{A}, V::TensorMapSpace) where {A<:CuArray} = $randfun(eltype(A), A, V)
197-
function $randfun(rng::Random.AbstractRNG, ::Type{A}, V::TensorMapSpace) where {A<:CuArray}
204+
$randfun(::Type{A}, V::TensorMapSpace) where {A <: CuArray} = $randfun(eltype(A), A, V)
205+
function $randfun(rng::Random.AbstractRNG, ::Type{A}, V::TensorMapSpace) where {A <: CuArray}
198206
return $randfun(rng, eltype(A), A, V)
199207
end
200208

201209
# filling in default rng
202-
function $randfun(::Type{T}, ::Type{A}, V::TensorMapSpace) where {T, A<:CuArray{T}}
210+
function $randfun(::Type{T}, ::Type{A}, V::TensorMapSpace) where {T, A <: CuArray{T}}
203211
return $randfun(Random.default_rng(), T, A, V)
204212
end
205213

206214
# implementation
207-
function $randfun(rng::Random.AbstractRNG, ::Type{T},
208-
::Type{A}, V::TensorMapSpace) where {T, A<:CuArray{T}}
215+
function $randfun(
216+
rng::Random.AbstractRNG, ::Type{T},
217+
::Type{A}, V::TensorMapSpace
218+
) where {T, A <: CuArray{T}}
209219
t = CuTensorMap{T}(undef, V)
210220
$randfun!(rng, t)
211221
return t

src/auxiliary/random.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ randisometry(dims::Base.Dims{2}) = randisometry(Float64, Matrix{Float64}, dims)
1010
function randisometry(::Type{T}, dims::Base.Dims{2}) where {T <: Number}
1111
return randisometry(Random.default_rng(), T, dims)
1212
end
13-
function randisometry(::Type{T}, ::Type{A}, dims::Base.Dims{2}) where {T <: Number, A<:AbstractArray{T}}
13+
function randisometry(::Type{T}, ::Type{A}, dims::Base.Dims{2}) where {T <: Number, A <: AbstractArray{T}}
1414
return randisometry(Random.default_rng(), T, A, dims)
1515
end
16-
function randisometry(rng::Random.AbstractRNG, ::Type{T}, ::Type{A}, dims::Base.Dims{2}) where {T <: Number, A<:AbstractArray{T}}
16+
function randisometry(rng::Random.AbstractRNG, ::Type{T}, ::Type{A}, dims::Base.Dims{2}) where {T <: Number, A <: AbstractArray{T}}
1717
return randisometry!(rng, A(undef, dims))
1818
end
1919

test/cuda/tensors.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -353,14 +353,14 @@ for V in spacelist
353353
H = CUDA.randn(ComplexF64, V2 * V4, V2 * V4)
354354
CUDA.@allowscalar begin
355355
@tensor HrA12[a, s1, s2, c] := rhoL[a, a'] * conj(A1[a', t1, b]) *
356-
A2[b, t2, c'] * rhoR[c', c] *
357-
H[s1, s2, t1, t2]
356+
A2[b, t2, c'] * rhoR[c', c] *
357+
H[s1, s2, t1, t2]
358358

359359
@tensor HrA12array[a, s1, s2, c] := ad(rhoL)[a, a'] *
360-
conj(ad(A1)[a', t1, b]) *
361-
ad(A2)[b, t2, c'] *
362-
ad(rhoR)[c', c] *
363-
ad(H)[s1, s2, t1, t2]
360+
conj(ad(A1)[a', t1, b]) *
361+
ad(A2)[b, t2, c'] *
362+
ad(rhoR)[c', c] *
363+
ad(H)[s1, s2, t1, t2]
364364
end
365365
@test HrA12array ad(HrA12)
366366
end
@@ -484,8 +484,8 @@ for V in spacelist
484484
t = rmul!(t, 1 / norm(S₀, p))
485485
# Probably shouldn't allow truncerr and truncdim, as these require scalar indexing?
486486
CUDA.@allowscalar begin
487-
U, S, V = tsvd(t; trunc=truncbelow(1 / dim(domain(S₀))), p=p)
488-
U′, S′, V′ = tsvd(t; trunc=truncspace(space(S, 1)), p=p)
487+
U, S, V = tsvd(t; trunc = truncbelow(1 / dim(domain(S₀))), p = p)
488+
U′, S′, V′ = tsvd(t; trunc = truncspace(space(S, 1)), p = p)
489489
end
490490
@test (U, S, V) == (U′, S′, V′)
491491
end

0 commit comments

Comments
 (0)