Skip to content

Commit b422d66

Browse files
committed
Fix IfElse impl and test known branch execution
1 parent 9d0c67c commit b422d66

File tree

3 files changed

+76
-20
lines changed

3 files changed

+76
-20
lines changed

src/bloqade/analysis/fidelity/analysis.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,17 @@ class FidelityAnalysis(AddressAnalysis):
1919
## Usage examples
2020
2121
```
22-
from bloqade import qasm2
23-
from bloqade.noise import native
24-
from bloqade.analysis.fidelity import FidelityAnalysis
25-
from bloqade.qasm2.passes.noise import NoisePass
22+
from bloqade import squin
2623
27-
noise_main = qasm2.extended.add(native.dialect)
28-
29-
@noise_main
24+
@squin.kernel
3025
def main():
31-
q = qasm2.qreg(2)
32-
qasm2.x(q[0])
26+
q = squin.qalloc(1)
27+
squin.x(q[0])
28+
squin.depolarize(q[0])
3329
return q
3430
35-
NoisePass(main.dialects)(main)
36-
3731
fid_analysis = FidelityAnalysis(main.dialects)
38-
fid_analysis.run_analysis(main, no_raise=False)
32+
fid_analysis.run(main)
3933
4034
gate_fidelities = fid_analysis.gate_fidelities
4135
qubit_survival_probs = fid_analysis.qubit_survival_fidelities

src/bloqade/analysis/fidelity/impls.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from kirin import interp
2-
from kirin.analysis import ForwardFrame
2+
from kirin.analysis import ForwardFrame, const
33
from kirin.dialects import scf
44

5-
from bloqade.analysis.address import Address
5+
from bloqade.analysis.address import Address, ConstResult
66

77
from .analysis import FidelityAnalysis
88

@@ -18,17 +18,27 @@ def if_else(
1818
current_gate_fidelities = interp_.gate_fidelities
1919
current_survival_fidelities = interp_.qubit_survival_fidelities
2020

21-
# TODO: check if the condition is constant and fix the branch in that case
22-
# run both branches
21+
address_cond = frame.get(stmt.cond)
22+
23+
# NOTE: if the condition is known at compile time, run specific branch
24+
if isinstance(address_cond, ConstResult) and isinstance(
25+
const_cond := address_cond.result, const.Value
26+
):
27+
body = stmt.then_body if const_cond.data else stmt.else_body
28+
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
29+
ret = interp_.frame_call_region(body_frame, stmt, body, address_cond)
30+
return ret
31+
32+
# NOTE: runtime condition, evaluate both
2333
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
2434
# NOTE: reset fidelities before stepping into the then-body
2535
interp_.reset_fidelities()
2636

27-
interp_.frame_call_region(
37+
then_results = interp_.frame_call_region(
2838
then_frame,
2939
stmt,
3040
stmt.then_body,
31-
*(interp_.lattice.bottom() for _ in range(len(stmt.args))),
41+
address_cond,
3242
)
3343
then_fids = interp_.gate_fidelities
3444
then_survival = interp_.qubit_survival_fidelities
@@ -37,11 +47,11 @@ def if_else(
3747
# NOTE: reset again before stepping into else-body
3848
interp_.reset_fidelities()
3949

40-
interp_.frame_call_region(
50+
else_results = interp_.frame_call_region(
4151
else_frame,
4252
stmt,
4353
stmt.else_body,
44-
*(interp_.lattice.bottom() for _ in range(len(stmt.args))),
54+
address_cond,
4555
)
4656

4757
else_fids = interp_.gate_fidelities
@@ -60,3 +70,17 @@ def if_else(
6070
then_survival,
6171
else_survival,
6272
)
73+
74+
# TODO: pick the non-return value
75+
if isinstance(then_results, interp.ReturnValue) and isinstance(
76+
else_results, interp.ReturnValue
77+
):
78+
return interp.ReturnValue(then_results.value.join(else_results.value))
79+
elif isinstance(then_results, interp.ReturnValue):
80+
ret = else_results
81+
elif isinstance(else_results, interp.ReturnValue):
82+
ret = then_results
83+
else:
84+
ret = interp_.join_results(then_results, else_results)
85+
86+
return ret

test/analysis/fidelity/test_fidelity.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,41 @@ def main():
344344
]
345345
+ [FidelityRange(0.87, 0.87)] * 4
346346
) # squin.correlated_qubit_loss
347+
348+
349+
def test_squin_know_if():
350+
@squin.kernel
351+
def main():
352+
x = True
353+
q = squin.qalloc(4)
354+
355+
if x:
356+
squin.depolarize(0.1, q[0])
357+
squin.qubit_loss(0.1, q[0])
358+
else:
359+
squin.depolarize(0.1, q[1])
360+
squin.qubit_loss(0.1, q[1])
361+
362+
if not x:
363+
squin.depolarize(0.2, q[2])
364+
squin.qubit_loss(0.2, q[2])
365+
else:
366+
squin.depolarize(0.2, q[3])
367+
squin.qubit_loss(0.2, q[3])
368+
369+
fidelity_analysis = FidelityAnalysis(main.dialects)
370+
fidelity_analysis.run(main)
371+
372+
assert fidelity_analysis.gate_fidelities == [
373+
FidelityRange(0.9, 0.9),
374+
FidelityRange(1.0, 1.0),
375+
FidelityRange(1.0, 1.0),
376+
FidelityRange(0.8, 0.8),
377+
]
378+
379+
assert fidelity_analysis.qubit_survival_fidelities == [
380+
FidelityRange(0.9, 0.9),
381+
FidelityRange(1.0, 1.0),
382+
FidelityRange(1.0, 1.0),
383+
FidelityRange(0.8, 0.8),
384+
]

0 commit comments

Comments
 (0)