Skip to content

Commit 1f35b03

Browse files
YanickKiYanick Kind
andauthored
Add gradient of textbook fidelity and a test with FD (#573)
* Add gradient of textbook fidelity and a test with FD * Fix preexisting tests specializes.jl * Add fidelity2 and fidelity2' Implement a squared version of the fidelity. This improves numerical stability for algorithms that intentionally minimize the overlap, since, in contrast to the existing gradient fidelity', no division by the overlap occurs in the gradient. Additionally: -Implement tests for fidelity2 and fidelity2' -Add a docstring for fidelity2 and fidelity2' and integrate them into the interface -Add the internal function pure_state_fidelity2 which uses abs2 instead of abs The docstrings are concise but reference the fidelity docstring and clarify the differences. * Fix docstring and clean up tests * Remove unused circuits in tests * Let fidelity2 call into fidelity --------- Co-authored-by: Yanick Kind <[email protected]>
1 parent 0979147 commit 1f35b03

File tree

8 files changed

+175
-3
lines changed

8 files changed

+175
-3
lines changed

lib/YaoAPI/src/registers.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,17 @@ julia> fidelity(reg1, reg2)
652652
"""
653653
@interface fidelity
654654

655+
656+
"""
657+
fidelity2(register1, register2) -> Real/Vector{<:Real}
658+
fidelity2'(pair_or_reg1, pair_or_reg2) -> (g1, g2)
659+
660+
Return the [`fidelity`](@ref) squared.
661+
662+
`fidelity2'` returns the corresponding gradient of `fidelity2` with respect to registers and circuit parameters, similar to `fidelity'`.
663+
"""
664+
@interface fidelity2
665+
655666
"""
656667
tracedist(register1, register2)
657668

lib/YaoArrayRegister/src/YaoArrayRegister.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ export AbstractRegister,
7676
von_neumann_entropy,
7777
mutual_information,
7878
fidelity,
79+
fidelity2,
7980
focus!,
8081
focus,
8182
insert_qudits!,

lib/YaoArrayRegister/src/density_matrix.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ function YaoAPI.fidelity(m::DensityMatrix, n::DensityMatrix)
103103
return density_matrix_fidelity(m.state, n.state)
104104
end
105105

106+
YaoAPI.fidelity2(m::DensityMatrix, n::DensityMatrix) = YaoAPI.fidelity(m, n)^2
107+
106108
function YaoAPI.purify(r::DensityMatrix{D}; num_env::Int = nactive(r)) where {D}
107109
Ne = D ^ num_env
108110
Ns = size(r.state, 1)

lib/YaoArrayRegister/src/register.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,8 @@ function YaoAPI.fidelity(r1::ArrayReg, r2::ArrayReg)
370370
end
371371
end
372372

373+
YaoAPI.fidelity2(r1::AbstractArrayReg, r2::AbstractArrayReg) = YaoAPI.fidelity(r2, r1).^2
374+
373375
YaoAPI.tracedist(r1::ArrayReg, r2::ArrayReg) = tracedist(density_matrix(r1), density_matrix(r2))
374376
YaoAPI.tracedist(r1::BatchedArrayReg, r2::BatchedArrayReg) = tracedist.(r1, r2)
375377

lib/YaoArrayRegister/test/density_matrix.jl

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,49 @@ using LinearAlgebra
3737
) |> all
3838
end
3939

40+
@testset "test fidelity2" begin
41+
reg = rand_state(3)
42+
reg_ = rand_state(3)
43+
reg2 = clone(reg, 3)
44+
@test fidelity2(reg, reg) 1
45+
@test fidelity2(reg, reg_) < 1
46+
@test fidelity2(reg2, reg2) [1, 1, 1]
47+
@test isapprox(fidelity2(reg, reg_), fidelity(reg, reg_)^2)
48+
# mix
49+
reg4 = join(reg, reg)
50+
reg5 = join(reg_, reg_)
51+
focus!(reg4, 1:3)
52+
focus!(reg5, 1:3)
53+
@test isapprox(fidelity2(reg, reg_), fidelity2(reg4, reg5), atol = 1e-5)
54+
@test isapprox(fidelity2(reg4, reg5), fidelity(reg4, reg5)^2)
55+
56+
@test isapprox.(
57+
fidelity2(reg, reg_),
58+
fidelity2(clone(reg4, 3), clone(reg5, 3)),
59+
atol = 1e-5,
60+
) |> all
61+
62+
@test isapprox.(
63+
fidelity2(clone(reg4, 3), clone(reg5, 3)),
64+
fidelity(clone(reg4, 3), clone(reg5, 3)).^2,
65+
) |> all
66+
67+
# batch
68+
st = rand(ComplexF64, 8, 2)
69+
reg1 = BatchedArrayReg(st)
70+
reg2 = rand_state(3)
71+
72+
@test fidelity2(reg1, reg2)
73+
[fidelity2(ArrayReg(st[:, 1]), reg2), fidelity2(ArrayReg(st[:, 2]), reg2)]
74+
@test isapprox(fidelity2(reg1, reg2), fidelity(reg1, reg2).^2)
75+
76+
@test isapprox.(
77+
fidelity2(reg, reg_),
78+
fidelity2(clone(reg4, 3), clone(reg5, 3)),
79+
atol = 1e-5,
80+
) |> all
81+
end
82+
4083
@testset "test trace distance" begin
4184
reg = rand_state(3)
4285
reg_ = rand_state(3)
@@ -162,15 +205,18 @@ end
162205
r2 = density_matrix(reg2, (2, 1))
163206
expected = abs(tr(sqrt(sqrt(r1.state) * r2.state * sqrt(r1.state))))
164207
@test expected fidelity(r1, r2) atol=1e-5
165-
208+
@test expected^2 fidelity2(r1, r2) atol=1e-5
166209
# focused state is viewed as mixed state
167210
f1 = focus!(copy(reg1), (2, 1))
168211
f2 = focus!(copy(reg2), (2, 1))
169212
@test fidelity(r1, r2) fidelity(f1, f2) atol=1e-6
170-
213+
@test fidelity2(r1, r2) fidelity2(f1, f2) atol=1e-6
214+
@test fidelity2(r1, r2) fidelity(f1, f2)^2 atol=1e-6
171215
# fidelity between focused and pure state
172216
f1 = rand_state(2)
173217
@test fidelity(density_matrix(f1, (1, 2)), r2) fidelity(f1, f2) atol=1e-6
218+
@test fidelity2(density_matrix(f1, (1, 2)), r2) fidelity2(f1, f2) atol=1e-6
219+
@test fidelity2(density_matrix(f1, (1, 2)), r2) fidelity(f1, f2)^2 atol=1e-6
174220

175221
dm = rand_density_matrix(2)
176222
@test is_density_matrix(dm.state)

lib/YaoBlocks/src/YaoBlocks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import YaoAPI:
2828
dispatch!,
2929
expect,
3030
fidelity,
31+
fidelity2,
3132
focus!,
3233
getiparams,
3334
iparams_eltype,

lib/YaoBlocks/src/autodiff/specializes.jl

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
for F in [:expect, :fidelity, :operator_fidelity]
1+
for F in [:expect, :fidelity, :fidelity2, :operator_fidelity]
22
@eval Base.adjoint(::typeof($F)) = Adjoint($F)
33
@eval Base.show(io::IO, ::Adjoint{Any,typeof($F)}) = print(io, "$($F)'")
44
@eval Base.show(io::IO, ::MIME"text/plain", ::Adjoint{Any,typeof($F)}) =
@@ -24,6 +24,7 @@ end
2424
_eval(p::Pair{<:AbstractRegister,<:AbstractBlock}) = copy(p.first) |> p.second
2525
_eval(reg::AbstractRegister) = reg
2626
YaoAPI.fidelity(p1, p2) = fidelity(_eval(p1), _eval(p2))
27+
YaoAPI.fidelity2(p1, p2) = fidelity2(_eval(p1), _eval(p2))
2728

2829
function (::Adjoint{Any,typeof(fidelity)})(
2930
reg1::Union{AbstractArrayReg,Pair{<:AbstractArrayReg,<:AbstractBlock}},
@@ -32,6 +33,13 @@ function (::Adjoint{Any,typeof(fidelity)})(
3233
fidelity_g(reg1, reg2)
3334
end
3435

36+
function (::Adjoint{Any,typeof(fidelity2)})(
37+
reg1::Union{AbstractArrayReg,Pair{<:AbstractArrayReg,<:AbstractBlock}},
38+
reg2::Union{AbstractArrayReg,Pair{<:AbstractArrayReg,<:AbstractBlock}},
39+
)
40+
fidelity2_g(reg1, reg2)
41+
end
42+
3543
function fidelity_g(
3644
reg1::Union{AbstractArrayReg,Pair{<:AbstractArrayReg,<:AbstractBlock}},
3745
reg2::Union{AbstractArrayReg,Pair{<:AbstractArrayReg,<:AbstractBlock}},
@@ -82,6 +90,56 @@ please file an issue if you really need this feature.",
8290
return res1, res2
8391
end
8492

93+
function fidelity2_g(
94+
reg1::Union{AbstractArrayReg,Pair{<:AbstractArrayReg,<:AbstractBlock}},
95+
reg2::Union{AbstractArrayReg,Pair{<:AbstractArrayReg,<:AbstractBlock}},
96+
)
97+
if reg1 isa Pair
98+
in1, c1 = reg1
99+
out1 = copy(in1) |> c1
100+
else
101+
out1 = reg1
102+
end
103+
104+
if reg2 isa Pair
105+
in2, c2 = reg2
106+
out2 = copy(in2) |> c2
107+
else
108+
out2 = reg2
109+
end
110+
if nremain(out1) != 0
111+
throw(
112+
ArgumentError(
113+
"The gradient of registers with environment is not implemented yet.
114+
However, back propagating over a focused register is possible,
115+
please file an issue if you really need this feature.",
116+
),
117+
)
118+
end
119+
overlap = out1' * out2
120+
121+
out1δ = copy(out2)
122+
regscale!.(viewbatch.(Ref(out1δ), 1:length(overlap)), 2.0.*conj.(overlap))
123+
out2δ = copy(out1)
124+
regscale!.(viewbatch.(Ref(out2δ), 1:length(overlap)), 2.0.*overlap)
125+
126+
if reg1 isa Pair
127+
(_, in1δ), params1δ = apply_back((out1, out1δ), c1)
128+
res1 = in1δ => params1δ
129+
else
130+
res1 = out1δ
131+
end
132+
133+
if reg2 isa Pair
134+
(_, in2δ), params2δ = apply_back((out2, out2δ), c2)
135+
res2 = in2δ => params2δ
136+
else
137+
res2 = out2δ
138+
end
139+
140+
return res1, res2
141+
end
142+
85143
function (::Adjoint{Any,typeof(operator_fidelity)})(b1::AbstractBlock, b2::AbstractBlock)
86144
operator_fidelity_g(b1, b2)
87145
end

lib/YaoBlocks/test/autodiff/specializes.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,57 @@ end
105105
#@test isapprox(vec(g2.state), state_numgrad(reg2->sum(fidelity(reg1=>c1, reg2=>c2)), reg2), atol=1e-5)
106106
end
107107

108+
@testset "fidelity2 grad" begin
109+
nbit = 4
110+
Random.seed!(2)
111+
for nbatch in [NoBatch(), 10]
112+
reg1 = rand_state(nbit; nbatch = nbatch)
113+
reg2 = rand_state(nbit; nbatch = nbatch)
114+
c1 = qftcirc(nbit)
115+
c2 = chain(put(nbit, 2 => Rx(0.5)), control(nbit, 1, 3 => Ry(0.5)))
116+
117+
g1, g2 = fidelity2'(reg1, reg2)
118+
@test isapprox(
119+
vec(g1.state),
120+
state_numgrad(reg1 -> sum(fidelity2(reg1, reg2)), reg1),
121+
atol = 1e-4,
122+
)
123+
@test isapprox(
124+
vec(g2.state),
125+
state_numgrad(reg2 -> sum(fidelity2(reg1, reg2)), reg2),
126+
atol = 1e-4,
127+
)
128+
129+
(g1, pg1), (g2, pg2) = fidelity2'(reg1 => c1, reg2 => c2)
130+
npg1 = YaoBlocks.AD.ng(
131+
x -> sum(fidelity2(reg1 => dispatch!(c1, x), reg2 => c2)),
132+
parameters(c1),
133+
)
134+
npg2 = YaoBlocks.AD.ng(
135+
x -> sum(fidelity2(reg1 => c1, reg2 => dispatch!(c2, x))),
136+
parameters(c2),
137+
)
138+
@test isapprox(pg1, vec(npg1), atol = 1e-5)
139+
@test isapprox(pg2, vec(npg2), atol = 1e-5)
140+
@test isapprox(
141+
vec(g1.state),
142+
state_numgrad(reg1 -> sum(fidelity2(reg1 => c1, reg2 => c2)), reg1),
143+
atol = 1e-4,
144+
)
145+
@test isapprox(
146+
vec(g2.state),
147+
state_numgrad(reg2 -> sum(fidelity2(reg1 => c1, reg2 => c2)), reg2),
148+
atol = 1e-4,
149+
)
150+
end
151+
152+
nbatch = NoBatch()
153+
reg1 = rand_state(nbit; nbatch = nbatch) |> focus!(2, 1, 4)
154+
reg2 = rand_state(nbit; nbatch = nbatch) |> focus!(2, 1, 4)
155+
156+
@test_throws ArgumentError fidelity2'(reg1, reg2)
157+
end
158+
108159
@testset "operator fideliy" begin
109160
nbit = 4
110161
c1 = qftcirc(nbit)

0 commit comments

Comments
 (0)