Skip to content

Commit b324a8c

Browse files
authored
Fixes and tests for CuStateVec (#2664)
1 parent ffd75e8 commit b324a8c

File tree

4 files changed

+51
-38
lines changed

4 files changed

+51
-38
lines changed

lib/custatevec/src/libcustatevec.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ end
600600
initialize_context()
601601
@gcsafe_ccall libcustatevec.custatevecAccessorGet(handle::custatevecHandle_t,
602602
accessor::custatevecAccessorDescriptor_t,
603-
externalBuffer::Ptr{Cvoid},
603+
externalBuffer::PtrOrCuPtr{Cvoid},
604604
_begin::custatevecIndex_t,
605605
_end::custatevecIndex_t)::custatevecStatus_t
606606
end
@@ -609,7 +609,7 @@ end
609609
initialize_context()
610610
@gcsafe_ccall libcustatevec.custatevecAccessorSet(handle::custatevecHandle_t,
611611
accessor::custatevecAccessorDescriptor_t,
612-
externalBuffer::Ptr{Cvoid},
612+
externalBuffer::PtrOrCuPtr{Cvoid},
613613
_begin::custatevecIndex_t,
614614
_end::custatevecIndex_t)::custatevecStatus_t
615615
end
@@ -665,7 +665,7 @@ end
665665
svType)
666666
initialize_context()
667667
@gcsafe_ccall libcustatevec.custatevecInitializeStateVector(handle::custatevecHandle_t,
668-
sv::Ptr{Cvoid},
668+
sv::PtrOrCuPtr{Cvoid},
669669
svDataType::cudaDataType_t,
670670
nIndexBits::UInt32,
671671
svType::custatevecStateVectorType_t)::custatevecStatus_t

lib/custatevec/src/statevec.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
function initialize!(sv::CuStateVec, sv_type::custatevecStateVectorType_t)
2+
custatevecInitializeStateVector(handle(), sv.data, eltype(sv), sv.nbits, sv_type)
3+
sv
4+
end
5+
16
function applyPauliExp!(sv::CuStateVec, theta::Float64, paulis::Vector{<:Pauli}, targets::Vector{Int32}, controls::Vector{Int32}, controlValues::Vector{Int32}=fill(one(Int32), length(controls)))
27
cupaulis = CuStateVecPauli.(paulis)
38
custatevecApplyPauliRotation(handle(), sv.data, eltype(sv), sv.nbits, theta, cupaulis, targets, length(targets), controls, controlValues, length(controls))
@@ -178,10 +183,11 @@ function testMatrixType(matrix::Union{Matrix, CuMatrix}, adjoint::Bool, matrix_t
178183
return residualNorm[]
179184
end
180185

181-
function accessorSet(a::CuStateVecAccessor, external_buf::Union{Vector, CuVector}, i_begin::Int, i_end::Int)
182-
custatevecAccessorSet(handle(), a, external_buf, i_begin, i_end)
186+
# TODO attach this to the Julia indexing API
187+
function accessorSet!(a::CuStateVecAccessor, external_buf::Union{Vector, CuVector}, i_begin::Int, i_end::Int)
188+
custatevecAccessorSet(handle(), a, pointer(external_buf), i_begin, i_end)
183189
end
184190

185191
function accessorGet(a::CuStateVecAccessor, external_buf::Union{Vector, CuVector}, i_begin::Int, i_end::Int)
186-
custatevecAccessorGet(handle(), a, external_buf, i_begin, i_end)
192+
custatevecAccessorGet(handle(), a, pointer(external_buf), i_begin, i_end)
187193
end

lib/custatevec/src/types.jl

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,20 @@
33
## custatevec compute type
44

55
function Base.convert(::Type{custatevecComputeType_t}, T::DataType)
6-
if T == Float16
7-
return CUSTATEVEC_COMPUTE_16F
8-
elseif T == Float32
6+
if T == Float32
97
return CUSTATEVEC_COMPUTE_32F
108
elseif T == Float64
119
return CUSTATEVEC_COMPUTE_64F
12-
elseif T == UInt8
13-
return CUSTATEVEC_COMPUTE_8U
14-
elseif T == Int8
15-
return CUSTATEVEC_COMPUTE_8I
16-
elseif T == UInt32
17-
return CUSTATEVEC_COMPUTE_32U
18-
elseif T == Int32
19-
return CUSTATEVEC_COMPUTE_32I
2010
else
2111
throw(ArgumentError("cuStateVec type equivalent for compute type $T does not exist!"))
2212
end
2313
end
2414

2515
function Base.convert(::Type{Type}, T::custatevecComputeType_t)
26-
if T == CUSTATEVEC_COMPUTE_16F
27-
return Float16
28-
elseif T == CUSTATEVEC_COMPUTE_32F
16+
if T == CUSTATEVEC_COMPUTE_32F || T == CUSTATEVEC_COMPUTE_TF32
2917
return Float32
3018
elseif T == CUSTATEVEC_COMPUTE_64F
3119
return Float64
32-
elseif T == CUSTATEVEC_COMPUTE_8U
33-
return UInt8
34-
elseif T == CUSTATEVEC_COMPUTE_32U
35-
return UInt32
36-
elseif T == CUSTATEVEC_COMPUTE_8I
37-
return Int8
38-
elseif T == CUSTATEVEC_COMPUTE_32I
39-
return Int32
4020
else
4121
throw(ArgumentError("Julia type equivalent for compute type $T does not exist!"))
4222
end
@@ -45,9 +25,7 @@ end
4525
function compute_type(sv_type::DataType, mat_type::DataType)
4626
if sv_type == ComplexF64 && mat_type == ComplexF64
4727
return Float64
48-
elseif sv_type == ComplexF32 && mat_type == ComplexF64
49-
return Float32
50-
elseif sv_type == ComplexF32 && mat_type == ComplexF32
28+
elseif sv_type == ComplexF32 && mat_type <: Union{ComplexF64, ComplexF32}
5129
return Float32
5230
end
5331
end
@@ -67,13 +45,14 @@ mutable struct CuStateVec{T}
6745
data::CuVector{T}
6846
nbits::UInt32
6947
end
70-
function CuStateVec(T, n_qubits::Int)
48+
function CuStateVec(T, n_qubits::Int; sv_type::custatevecStateVectorType_t=CUSTATEVEC_STATE_VECTOR_TYPE_ZERO)
7149
data = CUDA.zeros(T, 2^n_qubits)
7250
# in most cases, taking the hit here for setting one element
7351
# is cheaper than building the entire thing on the CPU and
7452
# copying it over
75-
CUDA.@allowscalar data[1] = one(T)
76-
CuStateVec{T}(data, n_qubits)
53+
sv = CuStateVec{T}(data, n_qubits)
54+
initialize!(sv, sv_type)
55+
return sv
7756
end
7857
CuStateVec(v::CuVector{T}) where {T} = CuStateVec{T}(v, UInt32(log2(length(v))))
7958
CuStateVec(v::Vector{T}) where {T} = CuStateVec(CuVector{T}(v))
@@ -102,9 +81,13 @@ mutable struct CuStateVecAccessor
10281
function CuStateVecAccessor(sv::CuStateVec, bit_ordering::Vector{Int}, mask_bit_string::Vector{Int}, mask_ordering::Vector{Int})
10382
desc_ref = Ref{custatevecAccessorDescriptor_t}()
10483
extra_size = Ref{Csize_t}(0)
105-
custatevecAccessorCreate(handle(), pointer(sv.data), eltype(sv), sv.nbits, desc_ref, bit_ordering, length(bit_ordering), mask_bit_string, mark_ordering, length(mask_bit_string), extra_size)
84+
mask_string = isempty(mask_bit_string) ? C_NULL : mask_bit_string
85+
mask_order = isempty(mask_ordering) ? C_NULL : mask_ordering
86+
custatevecAccessorCreate(handle(), pointer(sv.data), eltype(sv), sv.nbits, desc_ref, bit_ordering, length(bit_ordering), mask_string, mask_order, length(mask_bit_string), extra_size)
10687
obj = new(desc_ref[], extra_size[])
10788
finalizer(custatevecAccessorDestroy, obj)
10889
obj
10990
end
11091
end
92+
93+
Base.unsafe_convert(::Type{custatevecAccessorDescriptor_t}, desc::CuStateVecAccessor) = desc.handle

lib/custatevec/test/runtests.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using cuStateVec
88
@info "cuStateVec version: $(cuStateVec.version())"
99

1010
@testset "cuStateVec" begin
11-
import cuStateVec: CuStateVec, applyMatrix!, applyMatrixBatched!, applyPauliExp!, applyGeneralizedPermutationMatrix!, expectation, expectationsOnPauliBasis, sample, testMatrixType, Pauli, PauliX, PauliY, PauliZ, PauliI, measureOnZBasis!, swapIndexBits!, abs2SumOnZBasis, collapseOnZBasis!, batchMeasure!, abs2SumArray, collapseByBitString!, abs2SumArrayBatched, collapseByBitStringBatched!
11+
import cuStateVec: CuStateVec, applyMatrix!, applyMatrixBatched!, applyPauliExp!, applyGeneralizedPermutationMatrix!, expectation, expectationsOnPauliBasis, sample, testMatrixType, Pauli, PauliX, PauliY, PauliZ, PauliI, measureOnZBasis!, swapIndexBits!, abs2SumOnZBasis, collapseOnZBasis!, batchMeasure!, abs2SumArray, collapseByBitString!, abs2SumArrayBatched, collapseByBitStringBatched!, accessorSet!, accessorGet, CuStateVecAccessor
1212

1313
@testset "applyMatrix! and expectation" begin
1414
# build a simple state and compute expectations
@@ -43,15 +43,23 @@ using cuStateVec
4343
n_q = 2
4444
@testset for elty in [ComplexF32, ComplexF64]
4545
H = convert(Matrix{elty}, (1/√2).*[1 1; 1 -1])
46-
X = convert(Matrix{elty}, [0 1; 1 0])
47-
Z = convert(Matrix{elty}, [1 0; 0 -1])
4846
sv = CuStateVec(elty, n_q)
4947
sv = applyMatrix!(sv, H, false, Int32[0], Int32[])
5048
sv = applyMatrix!(sv, H, false, Int32[1], Int32[])
5149
pauli_ops = [cuStateVec.Pauli[cuStateVec.PauliX()], cuStateVec.Pauli[cuStateVec.PauliX()]]
5250
exp_vals = expectationsOnPauliBasis(sv, pauli_ops, [[0], [1]])
5351
@test exp_vals[1] 1.0 atol=1e-6
5452
@test exp_vals[2] 1.0 atol=1e-6
53+
54+
55+
H = convert(Matrix{elty}, (1/√2).*[1 1; 1 -1])
56+
sv = CuStateVec(elty, n_q)
57+
sv = applyMatrix!(sv, H, false, Int32[0], Int32[])
58+
sv = applyMatrix!(sv, H, false, Int32[1], Int32[])
59+
pauli_ops = [cuStateVec.Pauli[cuStateVec.PauliY()], cuStateVec.Pauli[cuStateVec.PauliI()]]
60+
exp_vals = expectationsOnPauliBasis(sv, pauli_ops, [[0], [1]])
61+
@test exp_vals[1] 0.0 atol=1e-6
62+
@test exp_vals[2] 1.0 atol=1e-6
5563
end
5664
end
5765
@testset "applyMatrixBatched! and expectation" begin
@@ -248,4 +256,20 @@ using cuStateVec
248256
@test testMatrixType(CuMatrix{elty}(A), true, cuStateVec.CUSTATEVEC_MATRIX_TYPE_UNITARY) <= 200 * eps(real(elty))
249257
end
250258
end
259+
@testset "accessorSet!/accessorGet" begin
260+
nIndexBits = 3
261+
bitOrdering = [1, 2, 0]
262+
@testset for elty in [ComplexF32, ComplexF64]
263+
h_sv = zeros(elty, 2^nIndexBits)
264+
h_sv_result = elty[0; 0.1im; 0.1+0.1im; 0.1+0.2im; 0.2+0.2im; 0.3+0.3im; 0.3+0.4im; 0.4+0.5im]
265+
buffer = elty[0; 0.1im; 0.1+0.1im; 0.1+0.2im; 0.2+0.2im; 0.3+0.3im; 0.3+0.4im; 0.4+0.5im]
266+
267+
sv = CuStateVec(h_sv)
268+
acc = CuStateVecAccessor(sv, bitOrdering, Int[], Int[])
269+
accessorSet!(acc, buffer, 0, 2^nIndexBits)
270+
next_buf = similar(buffer)
271+
accessorGet(acc, next_buf, 0, 2^nIndexBits)
272+
@test next_buf == h_sv_result
273+
end
274+
end
251275
end

0 commit comments

Comments
 (0)