Skip to content

Commit 30f9167

Browse files
ytdHuangalbertomercurio
authored andcommitted
fix type conversion of tlist in time evolution
1 parent d7bc352 commit 30f9167

File tree

4 files changed

+13
-3
lines changed

4 files changed

+13
-3
lines changed

src/time_evolution/mcsolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ function mcsolveProblem(
190190
c_ops isa Nothing &&
191191
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))
192192

193-
t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
193+
t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
194194

195195
H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2
196196

src/time_evolution/mesolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ function mesolveProblem(
122122
is_time_dependent = !(H_t isa Nothing)
123123
progress_bar_val = makeVal(progress_bar)
124124

125-
t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
125+
t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
126126

127127
ρ0 = mat2vec(ket2dm(ψ0).data)
128128

src/time_evolution/sesolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ function sesolveProblem(
103103
is_time_dependent = !(H_t isa Nothing)
104104
progress_bar_val = makeVal(progress_bar)
105105

106-
t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
106+
t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
107107

108108
ϕ0 = get_data(ψ0)
109109

src/utilities.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,13 @@ _non_static_array_warning(argname, arg::AbstractVector{T}) where {T} =
6565
@warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" *
6666
join(arg, ", ") *
6767
")` instead of `$argname = $arg`." maxlog = 1
68+
69+
# convert tlist in time evolution
70+
_convert_tlist(::Int32, tlist::AbstractVector) = _convert_tlist(Val(32), tlist)
71+
_convert_tlist(::Float32, tlist::AbstractVector) = _convert_tlist(Val(32), tlist)
72+
_convert_tlist(::ComplexF32, tlist::AbstractVector) = _convert_tlist(Val(32), tlist)
73+
_convert_tlist(::Int64, tlist::AbstractVector) = _convert_tlist(Val(64), tlist)
74+
_convert_tlist(::Float64, tlist::AbstractVector) = _convert_tlist(Val(64), tlist)
75+
_convert_tlist(::ComplexF64, tlist::AbstractVector) = _convert_tlist(Val(64), tlist)
76+
_convert_tlist(::Val{32}, tlist::AbstractVector) = convert(Vector{Float32}, tlist)
77+
_convert_tlist(::Val{64}, tlist::AbstractVector) = convert(Vector{Float64}, tlist)

0 commit comments

Comments
 (0)