Skip to content

Commit b4659b9

Browse files
authored
Add support for Adapt which enables converting tensor networks to GPU (#187)
* Add support for Adapt which enables converting tensor networks to GPU * Bump to v0.11.13
1 parent 4efceb8 commit b4659b9

File tree

5 files changed

+48
-1
lines changed

5 files changed

+48
-1
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
4-
version = "0.11.12"
4+
version = "0.11.13"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -36,19 +36,22 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3636
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
3737

3838
[weakdeps]
39+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3940
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
4041
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
4142
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
4243
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
4344

4445
[extensions]
46+
ITensorNetworksAdaptExt = "Adapt"
4547
ITensorNetworksEinExprsExt = "EinExprs"
4648
ITensorNetworksGraphsFlowsExt = "GraphsFlows"
4749
ITensorNetworksOMEinsumContractionOrdersExt = "OMEinsumContractionOrders"
4850
ITensorNetworksObserversExt = "Observers"
4951

5052
[compat]
5153
AbstractTrees = "0.4.4"
54+
Adapt = "4"
5255
Combinatorics = "1"
5356
Compat = "3, 4"
5457
DataGraphs = "0.2.3"
@@ -82,6 +85,7 @@ TupleTools = "1.4"
8285
julia = "1.10"
8386

8487
[extras]
88+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8589
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
8690
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
8791
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module ITensorNetworksAdaptExt
2+
using Adapt: Adapt, adapt
3+
using ITensorNetworks: AbstractITensorNetwork, map_vertex_data_preserve_graph
4+
function Adapt.adapt_structure(to, tn::AbstractITensorNetwork)
5+
# TODO: Define and use:
6+
#
7+
# @preserve_graph map_vertex_data(adapt(to), tn)
8+
#
9+
# or just:
10+
#
11+
# @preserve_graph map(adapt(to), tn)
12+
return map_vertex_data_preserve_graph(adapt(to), tn)
13+
end
14+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
3+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
34
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
45
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
56
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"

test/test_ext/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[deps]
2+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3+
ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7"
4+
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
5+
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
@eval module $(gensym())
2+
using Adapt: Adapt, adapt
3+
using NamedGraphs.NamedGraphGenerators: named_grid
4+
using ITensorNetworks: random_tensornetwork, siteinds
5+
using ITensors: ITensors
6+
using Test: @test, @testset
7+
8+
struct SinglePrecisionAdaptor end
9+
single_precision(::Type{<:AbstractFloat}) = Float32
10+
single_precision(type::Type{<:Complex}) = complex(single_precision(real(type)))
11+
Adapt.adapt_storage(::SinglePrecisionAdaptor, x) = single_precision(eltype(x)).(x)
12+
13+
@testset "Test ITensorNetworksAdaptExt (eltype=$elt)" for elt in (
14+
Float32, Float64, Complex{Float32}, Complex{Float64}
15+
)
16+
g = named_grid((2, 2))
17+
s = siteinds("S=1/2", g)
18+
tn = random_tensornetwork(elt, s)
19+
@test ITensors.scalartype(tn) === elt
20+
tn′ = adapt(SinglePrecisionAdaptor(), tn)
21+
@test ITensors.scalartype(tn′) === single_precision(elt)
22+
end
23+
end

0 commit comments

Comments
 (0)