Skip to content

Commit 1cb66b8

Browse files
authored
Adapt extension (#344)
* add Adapt * convert to extension * add Adapt implementations * add Adapt tests * fix tests
1 parent dfe4026 commit 1cb66b8

File tree

4 files changed

+35
-1
lines changed

4 files changed

+35
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1818
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1919

2020
[weakdeps]
21+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2122
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2223
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2324
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2425
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2526

2627
[extensions]
28+
TensorKitAdaptExt = "Adapt"
2729
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
2830
TensorKitChainRulesCoreExt = "ChainRulesCore"
2931
TensorKitFiniteDifferencesExt = "FiniteDifferences"

ext/TensorKitAdaptExt.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
module TensorKitAdaptExt
2+
3+
using TensorKit
4+
using TensorKit: AdjointTensorMap
5+
using Adapt
6+
7+
function Adapt.adapt_structure(to, x::TensorMap)
8+
data′ = adapt(to, x.data)
9+
return TensorMap{eltype(data′)}(data′, space(x))
10+
end
11+
function Adapt.adapt_structure(to, x::AdjointTensorMap)
12+
return adjoint(adapt(to, parent(x)))
13+
end
14+
function Adapt.adapt_structure(to, x::DiagonalTensorMap)
15+
data′ = adapt(to, x.data)
16+
return DiagonalTensorMap(data′, x.domain)
17+
end
18+
19+
end

src/TensorKit.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ using LinearAlgebra: norm, dot, normalize, normalize!, tr,
142142
eigen, eigen!, svd, svd!,
143143
isposdef, isposdef!, rank, cond,
144144
Diagonal, Hermitian
145-
using MatrixAlgebraKit
146145

147146
import Base.Meta
148147

test/cuda/tensors.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,20 @@ for V in spacelist
115115
@test domain(t2) == one(W)
116116
end
117117
end
118+
@timedtestset "Adapt" begin
119+
W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5
120+
for T in (Int, Float32, ComplexF64)
121+
t = rand(T, W)
122+
t_gpu = @constinferred adapt(CuArray, t)
123+
@test storagetype(t_gpu) <: CuArray{T}
124+
@test scalartype(t_gpu) === scalartype(t)
125+
@test collect(t_gpu.data) == t.data
126+
127+
t_cpu = @constinferred adapt(Array, t_gpu)
128+
@test t_cpu == t
129+
@test storagetype(t_cpu) <: Array{T}
130+
end
131+
end
118132
@timedtestset "Tensor Dict conversion" begin
119133
W = V1 ⊗ V2 ⊗ V3 ← V4 ⊗ V5
120134
for T in (Int, Float32, ComplexF64)

0 commit comments

Comments
 (0)