Skip to content

Commit 08db501

Browse files
committed
Output higher dimensional arrays for multi-site operators
1 parent a0aefe6 commit 08db501

File tree

5 files changed

+86
-95
lines changed

5 files changed

+86
-95
lines changed

README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ julia> Pkg.add("QuantumOperatorDefinitions")
3333

3434
````julia
3535
using QuantumOperatorDefinitions: OpName, SiteType, StateName, , controlled, op, state
36-
using LinearAlgebra: Diagonal
3736
using SparseArrays: SparseMatrixCSC, SparseVector
3837
using Test: @test
3938

@@ -64,8 +63,6 @@ using Test: @test
6463
@test op("Y") == [0 -im; im 0]
6564
@test op("Z") == [1 0; 0 -1]
6665

67-
@test op("Z") isa Diagonal
68-
6966
@test op(Float32, "X") == [0 1; 1 0]
7067
@test eltype(op(Float32, "X")) === Float32
7168
@test op(SparseMatrixCSC, "X") == [0 1; 1 0]

examples/README.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ julia> Pkg.add("QuantumOperatorDefinitions")
3838
# ## Examples
3939

4040
using QuantumOperatorDefinitions: OpName, SiteType, StateName, , controlled, op, state
41-
using LinearAlgebra: Diagonal
4241
using SparseArrays: SparseMatrixCSC, SparseVector
4342
using Test: @test
4443

@@ -69,8 +68,6 @@ using Test: @test
6968
@test op("Y") == [0 -im; im 0]
7069
@test op("Z") == [1 0; 0 -1]
7170

72-
@test op("Z") isa Diagonal
73-
7471
@test op(Float32, "X") == [0 1; 1 0]
7572
@test eltype(op(Float32, "X")) === Float32
7673
@test op(SparseMatrixCSC, "X") == [0 1; 1 0]

src/op.jl

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ function (arrtype::Type{<:AbstractArray})(
5959
)
6060
return arrtype(n, domain)
6161
end
62+
function (arrtype::Type{<:AbstractArray})(n::StateOrOpName, domain::Tuple{Vararg{Integer}})
63+
return arrtype(n, Base.oneto.(domain))
64+
end
6265
(arrtype::Type{<:AbstractArray})(n::StateOrOpName, ts::SiteType...) = arrtype(n, ts)
6366
function (n::StateOrOpName)(domain...)
6467
# TODO: Try one alias at a time?
@@ -89,50 +92,37 @@ function nsites(n::StateOrOpName)
8992
return nsites(n′)
9093
end
9194

95+
# TODO: This does some unwanted conversions, like turning
96+
# `Diagonal` dense.
9297
function array(a::AbstractArray, ax::Tuple{Vararg{AbstractUnitRange}})
9398
return a[ax...]
9499
end
95100

96-
function op_convert(
97-
arrtype::Type{<:AbstractArray{<:Any,N}},
98-
domain::Tuple{Vararg{AbstractUnitRange}},
99-
a::AbstractArray{<:Any,N},
100-
) where {N}
101-
ax = (domain..., domain...)
102-
a′ = array(a, ax)
103-
return convert(arrtype, a′)
104-
end
105-
function op_convert(
106-
arrtype::Type{<:AbstractArray}, domain::Tuple{Vararg{AbstractUnitRange}}, a::AbstractArray
107-
)
108-
ax = (domain..., domain...)
109-
a′ = array(a, ax)
110-
return convert(arrtype, a′)
101+
function state_or_op_axes(::OpName, domain::Tuple{Vararg{AbstractUnitRange}})
102+
return (domain..., domain...)
111103
end
112-
function op_convert(
113-
arrtype::Type{<:AbstractArray{<:Any,N}},
104+
105+
function state_or_op_convert(
106+
n::StateOrOpName,
107+
arrtype::Type{<:AbstractArray},
114108
domain::Tuple{Vararg{AbstractUnitRange}},
115109
a::AbstractArray,
116-
) where {N}
117-
ax = (domain..., domain...)
118-
@assert length(ax) == N
110+
)
111+
ax = state_or_op_axes(n, domain)
119112
a′ = reshape(a, length.(ax))
120113
a′′ = array(a′, ax)
121114
return convert(arrtype, a′′)
122115
end
123-
function (arrtype::Type{<:AbstractArray})(n::OpName, domain::Tuple{Vararg{SiteType}})
116+
117+
function (arrtype::Type{<:AbstractArray})(n::StateOrOpName, domain::Tuple{Vararg{SiteType}})
124118
domain′ = AbstractUnitRange.(domain)
125-
return op_convert(arrtype, domain′, n(domain...))
119+
return state_or_op_convert(n, arrtype, domain′, n(domain...))
126120
end
127-
128121
function (arrtype::Type{<:AbstractArray})(
129-
n::OpName, domain::Tuple{Vararg{AbstractUnitRange}}
122+
n::StateOrOpName, domain::Tuple{Vararg{AbstractUnitRange}}
130123
)
131124
# TODO: Make `(::OpName)(domain...)` constructor process more general inputs.
132-
return op_convert(arrtype, domain, n(Int.(length.(domain))...))
133-
end
134-
function (arrtype::Type{<:AbstractArray})(n::OpName, domain::Tuple{Vararg{Integer}})
135-
return arrtype(n, Base.oneto.(domain))
125+
return state_or_op_convert(n, arrtype, domain, n(Int.(length.(domain))...))
136126
end
137127

138128
function op(arrtype::Type{<:AbstractArray}, n::String, domain...; kwargs...)
@@ -498,10 +488,10 @@ function (n::OpName"Controlled")(domain...)
498488
d_control = prod(to_dim.(domain[1:nc]))
499489
return cat(I(d_control), n.arg(domain[(nc + 1):end]...); dims=(1, 2))
500490
end
501-
@op_alias "CNOT" "Controlled" op = OpName"X"()
502-
@op_alias "CX" "Controlled" op = OpName"X"()
503-
@op_alias "CY" "Controlled" op = OpName"Y"()
504-
@op_alias "CZ" "Controlled" op = OpName"Z"()
491+
@op_alias "CNOT" "Controlled" arg = OpName"X"()
492+
@op_alias "CX" "Controlled" arg = OpName"X"()
493+
@op_alias "CY" "Controlled" arg = OpName"Y"()
494+
@op_alias "CZ" "Controlled" arg = OpName"Z"()
505495
function alias(n::OpName"CPhase")
506496
return controlled(OpName"Phase"(; params(n)...))
507497
end
@@ -524,17 +514,17 @@ function alias(::OpName"CRn")
524514
end
525515
@op_alias "CRn̂" "CRn"
526516

527-
@op_alias "CCNOT" "Controlled" ncontrol = 2 op = OpName"X"()
517+
@op_alias "CCNOT" "Controlled" ncontrol = 2 arg = OpName"X"()
528518
@op_alias "Toffoli" "CCNOT"
529519
@op_alias "CCX" "CCNOT"
530520
@op_alias "TOFF" "CCNOT"
531521

532-
@op_alias "CSWAP" "Controlled" ncontrol = 2 op = OpName"SWAP"()
522+
@op_alias "CSWAP" "Controlled" ncontrol = 2 arg = OpName"SWAP"()
533523
@op_alias "Fredkin" "CSWAP"
534524
@op_alias "CSwap" "CSWAP"
535525
@op_alias "CS" "CSWAP"
536526

537-
@op_alias "CCCNOT" "Controlled" ncontrol = 3 op = OpName"X"()
527+
@op_alias "CCCNOT" "Controlled" ncontrol = 3 arg = OpName"X"()
538528

539529
## # 1-qudit rotation around generic axis n̂.
540530
## # exp(-im * α / 2 * n̂ ⋅ σ⃗)

src/state.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,8 @@ macro state_alias(name1, name2, params...)
3232
return state_alias_expr(name1, name2)
3333
end
3434

35-
function (arrtype::Type{<:AbstractArray})(n::StateName, domain::Tuple{Vararg{SiteType}})
36-
# TODO: Define `state_convert` to handle reshaping multisite states
37-
# to higher order arrays.
38-
return convert(arrtype, n(domain...))
39-
end
40-
function (arrtype::Type{<:AbstractArray})(n::StateName, domain::Tuple{Vararg{Integer}})
41-
# TODO: Define `state_convert` to handle reshaping multisite states
42-
# to higher order arrays.
43-
return convert(arrtype, n(Int.(domain)...))
35+
function state_or_op_axes(::StateName, domain::Tuple{Vararg{AbstractUnitRange}})
36+
return domain
4437
end
4538

4639
# This compiles operator expressions, such as:

test/test_basics.jl

Lines changed: 59 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,30 @@ const elts = (real_elts..., complex_elts...)
3333
[0 -im; im 0],
3434
[1 0; 0 -1],
3535
[0 0; 0 1],
36-
[1 0 0 0; 0 0 1 0; 0 1 0 0; 0 0 0 1],
37-
[1 0 0 0; 0 0 im 0; 0 im 0 0; 0 0 0 1],
38-
(_, θ) -> [
39-
cos/ 2) 0 0 -im*sin/ 2)
40-
0 cos/ 2) -im*sin/ 2) 0
41-
0 -im*sin/ 2) cos/ 2) 0
42-
-im*sin/ 2) 0 0 cos/ 2)
43-
],
44-
(_, θ) -> [
45-
cos/ 2) 0 0 im*sin/ 2)
46-
0 cos/ 2) -im*sin/ 2) 0
47-
0 -im*sin/ 2) cos/ 2) 0
48-
im*sin/ 2) 0 0 cos/ 2)
49-
],
50-
(_, θ) ->
36+
reshape([1 0 0 0; 0 0 1 0; 0 1 0 0; 0 0 0 1], (2, 2, 2, 2)),
37+
reshape([1 0 0 0; 0 0 im 0; 0 im 0 0; 0 0 0 1], (2, 2, 2, 2)),
38+
(_, θ) -> reshape(
39+
[
40+
cos/ 2) 0 0 -im*sin/ 2)
41+
0 cos/ 2) -im*sin/ 2) 0
42+
0 -im*sin/ 2) cos/ 2) 0
43+
-im*sin/ 2) 0 0 cos/ 2)
44+
],
45+
(2, 2, 2, 2),
46+
),
47+
(_, θ) -> reshape(
48+
[
49+
cos/ 2) 0 0 im*sin/ 2)
50+
0 cos/ 2) -im*sin/ 2) 0
51+
0 -im*sin/ 2) cos/ 2) 0
52+
im*sin/ 2) 0 0 cos/ 2)
53+
],
54+
(2, 2, 2, 2),
55+
),
56+
(_, θ) -> reshape(
5157
Diagonal([exp(-im * θ / 2), exp(im * θ / 2), exp(im * θ / 2), exp(-im * θ / 2)]),
58+
(2, 2, 2, 2),
59+
),
5260
[1 0; 0 0],
5361
[0 0; 0 1],
5462
[0 1; 0 0],
@@ -60,31 +68,37 @@ const elts = (real_elts..., complex_elts...)
6068
2 * [0 -im 0; im 0 -im; 0 im 0],
6169
2 * [1 0 0; 0 0 0; 0 0 -1],
6270
[0 0 0; 0 1 0; 0 0 2],
63-
[
64-
1 0 0 0 0 0 0 0 0
65-
0 0 0 1 0 0 0 0 0
66-
0 0 0 0 0 0 1 0 0
67-
0 1 0 0 0 0 0 0 0
68-
0 0 0 0 1 0 0 0 0
69-
0 0 0 0 0 0 0 1 0
70-
0 0 1 0 0 0 0 0 0
71-
0 0 0 0 0 1 0 0 0
72-
0 0 0 0 0 0 0 0 1
73-
],
74-
[
75-
1 0 0 0 0 0 0 0 0
76-
0 0 0 im 0 0 0 0 0
77-
0 0 0 0 0 0 im 0 0
78-
0 im 0 0 0 0 0 0 0
79-
0 0 0 0 1 0 0 0 0
80-
0 0 0 0 0 0 0 im 0
81-
0 0 im 0 0 0 0 0 0
82-
0 0 0 0 0 im 0 0 0
83-
0 0 0 0 0 0 0 0 1
84-
],
85-
(O, θ) -> exp(-im */ 2) * kron(O, O)),
86-
(O, θ) -> exp(-im */ 2) * kron(O, O)),
87-
(O, θ) -> exp(-im */ 2) * kron(O, O)),
71+
reshape(
72+
[
73+
1 0 0 0 0 0 0 0 0
74+
0 0 0 1 0 0 0 0 0
75+
0 0 0 0 0 0 1 0 0
76+
0 1 0 0 0 0 0 0 0
77+
0 0 0 0 1 0 0 0 0
78+
0 0 0 0 0 0 0 1 0
79+
0 0 1 0 0 0 0 0 0
80+
0 0 0 0 0 1 0 0 0
81+
0 0 0 0 0 0 0 0 1
82+
],
83+
(3, 3, 3, 3),
84+
),
85+
reshape(
86+
[
87+
1 0 0 0 0 0 0 0 0
88+
0 0 0 im 0 0 0 0 0
89+
0 0 0 0 0 0 im 0 0
90+
0 im 0 0 0 0 0 0 0
91+
0 0 0 0 1 0 0 0 0
92+
0 0 0 0 0 0 0 im 0
93+
0 0 im 0 0 0 0 0 0
94+
0 0 0 0 0 im 0 0 0
95+
0 0 0 0 0 0 0 0 1
96+
],
97+
(3, 3, 3, 3),
98+
),
99+
(O, θ) -> reshape(exp(-im */ 2) * kron(O, O)), (3, 3, 3, 3)),
100+
(O, θ) -> reshape(exp(-im */ 2) * kron(O, O)), (3, 3, 3, 3)),
101+
(O, θ) -> reshape(exp(-im */ 2) * kron(O, O)), (3, 3, 3, 3)),
88102
[1 0 0; 0 0 0; 0 0 0],
89103
[0 0 0; 0 1 0; 0 0 0],
90104
[0 1 0; 0 0 0; 0 0 0],
@@ -114,9 +128,9 @@ const elts = (real_elts..., complex_elts...)
114128
(OpName("Ry"; θ=π / 3), 1, complex_elts, exp(-im * π / 6 * Ymat)),
115129
(OpName("Rz"; θ=π / 3), 1, complex_elts, exp(-im * π / 6 * Zmat)),
116130
(OpName("SWAP"), 2, elts, SWAPmat),
117-
(OpName("√SWAP"), 2, complex_elts, SWAPmat),
131+
# (OpName("√SWAP"), 2, complex_elts, √SWAPmat),
118132
(OpName("iSWAP"), 2, complex_elts, iSWAPmat),
119-
(OpName("√iSWAP"), 2, complex_elts, iSWAPmat),
133+
# (OpName("√iSWAP"), 2, complex_elts, √iSWAPmat),
120134
(OpName("Rxx"; θ=π / 3), 2, complex_elts, RXXmat(Xmat, π / 3)),
121135
(OpName("RXX"; θ=π / 3), 2, complex_elts, RXXmat(Xmat, π / 3)),
122136
(OpName("Ryy"; θ=π / 3), 2, complex_elts, RYYmat(Ymat, π / 3)),
@@ -128,7 +142,7 @@ const elts = (real_elts..., complex_elts...)
128142
(OpName("StandardBasis"; index=(1, 2)), 1, elts, StandardBasis12mat),
129143
)
130144
@test nsites(o) == nbits
131-
for arraytype in (AbstractArray, AbstractMatrix, Array, Matrix)
145+
for arraytype in (AbstractArray, Array)
132146
for elt in elts
133147
ts = ntuple(Returns(t), nbits)
134148
lens = ntuple(Returns(len), nbits)
@@ -149,7 +163,7 @@ const elts = (real_elts..., complex_elts...)
149163
@test op("X * Y + 2 * Z") == op("X") * op("Y") + 2 * op("Z")
150164
@test op("exp(im * (X * Y + 2 * Z))") == exp(im * (op("X") * op("Y") + 2 * op("Z")))
151165
@test op("exp(im * (X ⊗ Y + Z ⊗ Z))") ==
152-
exp(im * (kron(op("X"), op("Y")) + kron(op("Z"), op("Z"))))
166+
reshape(exp(im * (kron(op("X"), op("Y")) + kron(op("Z"), op("Z")))), (2, 2, 2, 2))
153167
@test op("Ry{θ=π/2}") == op("Ry"; θ=π / 2)
154168
# Awkward parsing corner cases.
155169
@test op("S+") == Matrix(OpName("S+"))
@@ -187,7 +201,7 @@ const elts = (real_elts..., complex_elts...)
187201
@test state("2", 3) == [0, 0, 1]
188202

189203
@test state("|0⟩ + 2|+⟩") == state("0") + 2 * state("+")
190-
@test state("|0⟩ ⊗ |+⟩") == kron(state("0"), state("+"))
204+
@test state("|0⟩ ⊗ |+⟩") == reshape(kron(state("0"), state("+")), (2, 2))
191205
end
192206
@testset "Electron/tJ" begin
193207
for (ns, x) in (

0 commit comments

Comments
 (0)