Skip to content

Commit 18524a6

Browse files
ytdHuangalbertomercurio
authored andcommitted
introduce inner function _convert_u0
1 parent 99dcb09 commit 18524a6

File tree

5 files changed

+11
-14
lines changed

5 files changed

+11
-14
lines changed

ext/QuantumToolboxCUDAExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module QuantumToolboxCUDAExt
22

33
using QuantumToolbox
4+
import QuantumToolbox: _convert_u0
45
import CUDA: cu, CuArray
56
import CUDA.CUSPARSE: CuSparseVector, CuSparseMatrixCSC, CuSparseMatrixCSR
67
import SparseArrays: SparseVector, SparseMatrixCSC
@@ -89,4 +90,7 @@ _change_eltype(::Type{T}, ::Val{32}) where {T<:AbstractFloat} = Float32
8990
_change_eltype(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64
9091
_change_eltype(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32
9192

93+
# make sure u0 in time evolution is dense vector and has complex element type
94+
_convert_u0(u0::Union{CuArray{T},CuSparseVector{T}}) where {T<:Number} = convert(CuArray{complex(T)}, u0)
95+
9296
end

src/time_evolution/mcsolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ function mcsolveProblem(
190190
c_ops isa Nothing &&
191191
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))
192192

193-
t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
193+
t_l = convert(Vector{real(eltype(ψ0))}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
194194

195195
H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2
196196

src/time_evolution/mesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ function mesolveProblem(
122122
is_time_dependent = !(H_t isa Nothing)
123123
progress_bar_val = makeVal(progress_bar)
124124

125-
t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
125+
ρ0 = _convert_u0(mat2vec(ket2dm(ψ0).data))
126126

127-
ρ0 = mat2vec(ket2dm(ψ0).data)
127+
t_l = convert(Vector{real(eltype(ρ0))}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
128128

129129
L = liouvillian(H, c_ops).data
130130
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))

src/time_evolution/sesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,9 @@ function sesolveProblem(
103103
is_time_dependent = !(H_t isa Nothing)
104104
progress_bar_val = makeVal(progress_bar)
105105

106-
t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
106+
ϕ0 = _convert_u0(get_data(ψ0))
107107

108-
ϕ0 = get_data(ψ0)
108+
t_l = convert(Vector{real(eltype(ϕ0))}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
109109

110110
U = -1im * get_data(H)
111111
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))

src/utilities.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,5 @@ _non_static_array_warning(argname, arg::AbstractVector{T}) where {T} =
6666
join(arg, ", ") *
6767
")` instead of `$argname = $arg`." maxlog = 1
6868

69-
# convert tlist in time evolution
70-
_convert_tlist(::Type{Int32}, tlist::AbstractVector) = _convert_tlist(Val(32), tlist)
71-
_convert_tlist(::Type{Float32}, tlist::AbstractVector) = _convert_tlist(Val(32), tlist)
72-
_convert_tlist(::Type{ComplexF32}, tlist::AbstractVector) = _convert_tlist(Val(32), tlist)
73-
_convert_tlist(::Type{Int64}, tlist::AbstractVector) = _convert_tlist(Val(64), tlist)
74-
_convert_tlist(::Type{Float64}, tlist::AbstractVector) = _convert_tlist(Val(64), tlist)
75-
_convert_tlist(::Type{ComplexF64}, tlist::AbstractVector) = _convert_tlist(Val(64), tlist)
76-
_convert_tlist(::Val{32}, tlist::AbstractVector) = convert(Vector{Float32}, tlist)
77-
_convert_tlist(::Val{64}, tlist::AbstractVector) = convert(Vector{Float64}, tlist)
69+
# make sure u0 in time evolution is dense vector and has complex element type
70+
_convert_u0(u0::AbstractVector{T}) where {T<:Number} = convert(Vector{complex(T)}, u0)

0 commit comments

Comments
 (0)