Skip to content

Commit 740838b

Browse files
authored
Make it easier to run example 3 on GPU (#22)
* Make it easier to run example 3 on GPU * Bump to v0.2.4
1 parent 3cc10c1 commit 740838b

File tree

3 files changed

+34
-15
lines changed

3 files changed

+34
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorMPS"
22
uuid = "0d1a4710-d33b-49a5-8f18-73bdf49b47e2"
33
authors = ["Matthew Fishman <[email protected]>", "Miles Stoudenmire <[email protected]>"]
4-
version = "0.2.3"
4+
version = "0.2.4"
55

66
[deps]
77
ITensorTDVP = "25707e16-a4db-4a07-99d9-4d67b7af0342"

examples/03_tdvp_time_dependent.jl

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,24 @@ using Random: Random
66
include("03_models.jl")
77
include("03_updaters.jl")
88

9-
function main()
9+
"""
10+
Run the example on CPU:
11+
```julia
12+
main()
13+
```
14+
15+
Run the example on CPU with single precision:
16+
```julia
17+
main(; eltype=Float32)
18+
```
19+
20+
Run the example on GPU:
21+
```julia
22+
using CUDA: cu
23+
main(; eltype=Float32, device=cu)
24+
```
25+
"""
26+
function main(; eltype=Float64, device=identity)
1027
Random.seed!(1234)
1128

1229
# Time dependent Hamiltonian is:
@@ -24,16 +41,16 @@ function main()
2441
outputlevel = 3
2542

2643
# Frequency of time dependent terms
27-
ω₁ = 0.1
28-
ω₂ = 0.2
44+
ω₁ = one(eltype) / 10
45+
ω₂ = one(eltype) / 5
2946

3047
# Nearest and next-nearest neighbor
3148
# Heisenberg couplings.
32-
J₁ = 1.0
33-
J₂ = 1.0
49+
J₁ = one(eltype)
50+
J₂ = one(eltype)
3451

35-
time_step = 0.1
36-
time_stop = 1.0
52+
time_step = one(eltype) / 10
53+
time_stop = one(eltype)
3754

3855
# nsite-update TDVP
3956
nsite = 2
@@ -46,9 +63,9 @@ function main()
4663

4764
# TDVP truncation parameters
4865
maxdim = 100
49-
cutoff = 1e-8
66+
cutoff = (eps(eltype))
5067

51-
tol = 1e-15
68+
tol = 10 * eps(eltype)
5269

5370
@show n
5471
@show ω₁, ω₂
@@ -61,18 +78,20 @@ function main()
6178
f⃗ = map-> (t -> cos* t)), ω⃗)
6279

6380
# H₀ = H(0) = H₁(0) + H₂(0) + …
64-
ℋ₁₀ = heisenberg(n; J=J₁, J2=0.0)
65-
ℋ₂₀ = heisenberg(n; J=0.0, J2=J₂)
81+
ℋ₁₀ = heisenberg(n; J=J₁, J2=zero(eltype))
82+
ℋ₂₀ = heisenberg(n; J=zero(eltype), J2=J₂)
6683
ℋ⃗₀ = (ℋ₁₀, ℋ₂₀)
6784

6885
s = siteinds("S=1/2", n)
6986

70-
H⃗₀ = map(ℋ₀ -> MPO(ℋ₀, s), ℋ⃗₀)
87+
H⃗₀ = map(ℋ₀ -> device(MPO(eltype, ℋ₀, s)), ℋ⃗₀)
7188

7289
# Initial state, ψ₀ = ψ(0)
7390
# Initialize as complex since that is what OrdinaryDiffEq.jl/DifferentialEquations.jl
7491
# expects.
75-
ψ₀ = complex.(random_mps(s, j -> isodd(j) ? "" : ""; linkdims=start_linkdim))
92+
ψ₀ = device(
93+
complex.(random_mps(eltype, s, j -> isodd(j) ? "" : ""; linkdims=start_linkdim))
94+
)
7695

7796
@show norm(ψ₀)
7897

examples/03_updaters.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function ode_updater(operator, init; internal_kwargs, alg=Tsit5(), kwargs...)
99
time_span = typeof(time_step).((current_time, current_time + time_step))
1010
init_vec, to_itensor = to_vec(init)
1111
f(init::ITensor, p, t) = operator(t)(init)
12-
f(init_vec::Vector, p, t) = to_vec(f(to_itensor(init_vec), p, t))[1]
12+
f(init_vec::AbstractArray, p, t) = to_vec(f(to_itensor(init_vec), p, t))[1]
1313
prob = ODEProblem(f, init_vec, time_span)
1414
sol = solve(prob, alg; kwargs...)
1515
state_vec = sol.u[end]

0 commit comments

Comments
 (0)