Skip to content

Commit aaa2ed7

Browse files
authored
Fix issue 545 - autodiff with batched-arrayreg (#547)
1 parent 9a0e313 commit aaa2ed7

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

lib/YaoBlocks/src/autodiff/chainrules_patch.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,19 @@ end
7979
function rrule(::typeof(apply), reg::AbstractArrayReg, block::AbstractBlock)
8080
out = apply(reg, block)
8181
out, function (outδ)
82-
(in, inδ), paramsδ = apply_back((copy(out), tangent_to_reg(typeof(out), outδ)), block)
82+
(in, inδ), paramsδ = apply_back((copy(out), tangent_to_reg(out, outδ)), block)
8383
return (NoTangent(), inδ, create_circuit_tangent(block, paramsδ))
8484
end
8585
end
8686
function rrule(::typeof(apply), reg::AbstractArrayReg, block::AbstractAdd)
8787
out = apply(reg, block)
8888
out, function (outδ)
89-
(in, inδ), paramsδ = apply_back((copy(out), tangent_to_reg(typeof(out), outδ)), block; in = reg)
89+
(in, inδ), paramsδ = apply_back((copy(out), tangent_to_reg(out, outδ)), block; in = reg)
9090
return (NoTangent(), inδ, create_circuit_tangent(block, paramsδ))
9191
end
9292
end
93-
tangent_to_reg(::Type{T}, reg) where T<:AbstractArrayReg = reg isa Tangent ? T(reg.state) : reg
93+
tangent_to_reg(reg::T, grad) where {D, T<:ArrayReg{D}} = grad isa Tangent ? ArrayReg{D}(grad.state) : grad
94+
tangent_to_reg(reg::T, grad) where {D, T<:BatchedArrayReg{D}} = grad isa Tangent ? BatchedArrayReg{D}(grad.state, reg.nbatch) : grad
9495

9596

9697
function rrule(::typeof(dispatch), block::AbstractBlock, params)

lib/YaoBlocks/test/autodiff/chainrules_patch.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,4 +285,38 @@ end
285285
loss3(c, rand(nparameters(c)))
286286
zygote_grad = Zygote.gradient(θs->loss3(c, θs), rand(nparameters(c)))[1]
287287
@test zygote_grad isa Vector
288+
end
289+
290+
@testset "issue #545" begin
291+
function variational_circuit(q, nlayer)
292+
circuit = chain(q)
293+
for i = 1:nlayer
294+
push!(circuit, put(q, i=>chain(Rx(0.0), Ry(0.0), Rz(0.0))))
295+
end
296+
circuit
297+
end
298+
function bug_mwe()
299+
q = 4
300+
vqc = variational_circuit(q, 2)
301+
param = 4pi*rand(nparameters(vqc))
302+
303+
# bstate = rand_state(q) # works OK
304+
bstate = rand_state(q, nbatch=2) # Zygote errors
305+
306+
function loss(param)
307+
vqc_p = dispatch(vqc, param)
308+
bstate_p = apply(bstate, vqc_p)
309+
pr = probs(bstate_p)
310+
pr_s = sum(pr, dims = 2)
311+
r = -sum(pr_s[1:2^(4-1)])
312+
return r
313+
end
314+
315+
l = loss(param)
316+
grad = Zygote.gradient(loss, param)[1]
317+
318+
return grad
319+
end
320+
321+
@test bug_mwe() isa Vector
288322
end

0 commit comments

Comments
 (0)