Skip to content

Commit 9a27e19

Browse files
author
Katharine Hyatt
committed
More AMD updates and fixes
1 parent 7fd3318 commit 9a27e19

File tree

4 files changed

+322
-147
lines changed

4 files changed

+322
-147
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2525
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2626

2727
[sources]
28-
GPUArrays = {rev = "ksh/more_diag", url = "https://github.com/JuliaGPU/GPUArrays.jl"}
28+
GPUArrays = {rev = "master", url = "https://github.com/JuliaGPU/GPUArrays.jl"}
2929
MatrixAlgebraKit = {rev = "ksh/tk", url = "https://github.com/QuantumKitHub/MatrixAlgebraKit.jl"}
30-
AMDGPU = {rev = "ksh/diag_norm", url = "https://github.com/JuliaGPU/AMDGPU.jl"}
30+
AMDGPU = {rev = "master", url = "https://github.com/JuliaGPU/AMDGPU.jl"}
3131
cuTENSOR = {subdir = "lib/cutensor", url = "https://github.com/JuliaGPU/CUDA.jl"}
3232

3333
[extensions]

ext/TensorKitAMDGPUExt/TensorKitAMDGPUExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import AMDGPU: rand as rocrand, rand! as rocrand!, randn as rocrandn, randn! as
77
using TensorKit
88
import TensorKit.VectorInterface: scalartype as vi_scalartype
99
using TensorKit.Factorizations
10+
using TensorKit.Strided
1011
using TensorKit.Factorizations: AbstractAlgorithm
1112
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap
1213

ext/TensorKitAMDGPUExt/roctensormap.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,25 @@ function LinearAlgebra.isposdef(t::ROCTensorMap)
224224
end
225225
return true
226226
end
227+
228+
# Conversion to ROCArray:
229+
#----------------------
230+
# probably not optimized for speed, only for checking purposes
231+
"""
    convert(ROCArray, t::AbstractTensorMap)

Assemble the contents of the tensor map `t` into a single `ROCArray`.

When the sector type of `t` is `Trivial`, the single block `t[]` is converted
directly. Otherwise a zero-initialized array of size
`(dims(codomain(t))..., dims(domain(t))...)` is allocated and, for every fusion tree
pair `(f₁, f₂)` of `t`, the `TensorKit._kron` product of the block `t[f₁, f₂]` with the
array form of the tree pair is added into the slice selected by the uncoupled sectors.

Probably not optimized for speed; intended for checking purposes only.
"""
function Base.convert(::Type{ROCArray}, t::AbstractTensorMap)
    sector = sectortype(t)
    # Trivial sector type: only the block `t[]` needs to be converted.
    sector === Trivial && return convert(ROCArray, t[])

    V_cod, V_dom = codomain(t), domain(t)

    # The element type of the result follows the scalar type of the sector:
    # complex stays complex, integer keeps the tensor's scalar type as is, and
    # anything else is promoted to a floating-point type.
    coeff = sectorscalartype(sector)
    elt = if coeff <: Complex
        complex(scalartype(t))
    elseif coeff <: Integer
        scalartype(t)
    else
        float(scalartype(t))
    end

    dense = AMDGPU.zeros(elt, dims(V_cod)..., dims(V_dom)...)
    for (f₁, f₂) in fusiontrees(t)
        tree_array = convert(ROCArray, (f₁, f₂))
        # Index ranges of the slice addressed by this pair of fusion trees.
        slice_axes = (axes(V_cod, f₁.uncoupled)..., axes(V_dom, f₂.uncoupled)...)
        target = StridedView(dense)[slice_axes...]
        block = convert(ROCArray, t[f₁, f₂])
        # Accumulate rather than assign: contributions are added into the slice.
        add!(target, StridedView(TensorKit._kron(block, tree_array)))
    end
    return dense
end

0 commit comments

Comments
 (0)