Skip to content

Commit 2bc7cf8

Browse files
authored
Quantum runtime analysis (#68)
* adding analysis to check for quantum runtime instructions * refactor analysis * WIP: runtime analysis * adding statements to frame * removing unused imporst * Removing runtime impl this is not a quantum runtime * adding support for scf, and func dialects * adding run_method + importing method tables * Adding tests + fixing some minor bugs * Adding simplified interface * testing methods of other dialects + bug fixes * switching to using interpreters lattice property
1 parent 5e4eae8 commit 2bc7cf8

File tree

11 files changed

+431
-5
lines changed

11 files changed

+431
-5
lines changed
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from dataclasses import dataclass, field
2+
3+
from kirin import interp, ir
4+
from kirin.analysis import ForwardExtra, ForwardFrame, const
5+
from kirin.dialects import func, scf
6+
from kirin.lattice import EmptyLattice
7+
8+
9+
@dataclass
10+
class RuntimeFrame(ForwardFrame[EmptyLattice]):
11+
"""Frame for quantum runtime analysis.
12+
This frame is used to track the state of quantum operations within a method.
13+
"""
14+
15+
quantum_stmts: set[ir.Statement] = field(default_factory=set)
16+
"""Set of quantum statements in the frame."""
17+
is_quantum: bool = False
18+
"""Whether the frame contains quantum operations."""
19+
20+
21+
class RuntimeAnalysis(ForwardExtra[RuntimeFrame, EmptyLattice]):
22+
"""Forward dataflow analysis to check if a method has quantum runtime.
23+
24+
This analysis checks if a method contains any quantum runtime operations.
25+
It is used to determine if the method can be executed on a quantum device.
26+
"""
27+
28+
keys = ["runtime"]
29+
lattice = EmptyLattice
30+
31+
def eval_stmt_fallback(self, frame: RuntimeFrame, stmt: ir.Statement):
32+
return tuple(self.lattice.top() for _ in stmt.results)
33+
34+
def initialize_frame(
35+
self, code: ir.Statement, *, has_parent_access: bool = False
36+
) -> RuntimeFrame:
37+
return RuntimeFrame(code, has_parent_access=has_parent_access)
38+
39+
def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]):
40+
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
41+
42+
def has_quantum_runtime(self, method: ir.Method) -> bool:
43+
"""Return True if the method has quantum runtime operations, False otherwise."""
44+
frame, _ = self.run_analysis(method)
45+
return frame.is_quantum
46+
47+
48+
@scf.dialect.register(key="runtime")
49+
class Scf(interp.MethodTable):
50+
51+
@interp.impl(scf.IfElse)
52+
def ifelse(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: scf.IfElse):
53+
# If either branch is quantum, the whole ifelse is quantum
54+
with _interp.new_frame(stmt, has_parent_access=True) as then_frame:
55+
then_result = _interp.run_ssacfg_region(
56+
then_frame, stmt.then_body, (_interp.lattice.top(),)
57+
)
58+
59+
with _interp.new_frame(stmt, has_parent_access=True) as else_frame:
60+
else_result = _interp.run_ssacfg_region(
61+
else_frame, stmt.else_body, (_interp.lattice.top(),)
62+
)
63+
64+
frame.is_quantum = (
65+
frame.is_quantum or then_frame.is_quantum or else_frame.is_quantum
66+
)
67+
frame.quantum_stmts.update(then_frame.quantum_stmts, else_frame.quantum_stmts)
68+
match (then_result, else_result):
69+
case (interp.ReturnValue(), tuple()):
70+
return else_result
71+
case (tuple(), interp.ReturnValue()):
72+
return then_result
73+
case (tuple(), tuple()):
74+
return tuple(
75+
then_result.join(else_result)
76+
for then_result, else_result in zip(then_result, else_result)
77+
)
78+
case _:
79+
return tuple(_interp.lattice.top() for _ in stmt.results)
80+
81+
@interp.impl(scf.For)
82+
def for_loop(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: scf.For):
83+
args = (_interp.lattice.top(),) * (len(stmt.initializers) + 1)
84+
with _interp.new_frame(stmt, has_parent_access=True) as body_frame:
85+
result = _interp.run_ssacfg_region(
86+
body_frame, stmt.body, (_interp.lattice.bottom(),)
87+
)
88+
89+
frame.is_quantum = frame.is_quantum or body_frame.is_quantum
90+
frame.quantum_stmts.update(body_frame.quantum_stmts)
91+
if isinstance(result, interp.ReturnValue) or result is None:
92+
return args[1:]
93+
else:
94+
return tuple(arg.join(res) for arg, res in zip(args[1:], result))
95+
96+
@interp.impl(scf.Yield)
97+
def yield_stmt(
98+
self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: scf.Yield
99+
):
100+
return interp.YieldValue(frame.get_values(stmt.args))
101+
102+
103+
@func.dialect.register(key="runtime")
104+
class Func(interp.MethodTable):
105+
106+
@interp.impl(func.Invoke)
107+
def invoke(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: func.Invoke):
108+
args = (_interp.lattice.top(),) * len(stmt.inputs)
109+
callee_frame, result = _interp.run_method(stmt.callee, args)
110+
frame.is_quantum = frame.is_quantum or callee_frame.is_quantum
111+
return (result,)
112+
113+
@interp.impl(func.Call)
114+
def call(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: func.Call):
115+
# Check if the called method is quantum
116+
callee_result = stmt.callee.hints.get("const")
117+
args = (_interp.lattice.top(),) * len(stmt.inputs)
118+
if (
119+
isinstance(callee_result, const.PartialLambda)
120+
and (trait := callee_result.code.get_trait(ir.CallableStmtInterface))
121+
is not None
122+
):
123+
body = trait.get_callable_region(callee_result.code)
124+
with _interp.new_frame(stmt) as callee_frame:
125+
result = _interp.run_ssacfg_region(callee_frame, body, args)
126+
else:
127+
raise InterruptedError("Dynamic method calls are not supported")
128+
129+
frame.is_quantum = frame.is_quantum or callee_frame.is_quantum
130+
return (result,)
131+
132+
@interp.impl(func.Return)
133+
def return_stmt(
134+
self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: func.Return
135+
):
136+
return interp.ReturnValue(frame.get_values(stmt.results))

src/bloqade/shuttle/dialects/gate/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from . import runtime as runtime
12
from ._dialect import dialect as dialect
23
from ._interface import (
34
global_r as global_r,
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from kirin import interp
2+
3+
from bloqade.shuttle.analysis.runtime import (
4+
RuntimeAnalysis,
5+
RuntimeFrame,
6+
)
7+
8+
from ._dialect import dialect
9+
from .stmts import GlobalR, GlobalRz, LocalR, LocalRz, TopHatCZ
10+
11+
12+
@dialect.register(key="runtime")
13+
class HasQuantumRuntimeMethodTable(interp.MethodTable):
14+
15+
@interp.impl(TopHatCZ)
16+
@interp.impl(LocalRz)
17+
@interp.impl(LocalR)
18+
@interp.impl(GlobalR)
19+
@interp.impl(GlobalRz)
20+
def gate(
21+
self,
22+
interp: RuntimeAnalysis,
23+
frame: RuntimeFrame,
24+
stmt: TopHatCZ | LocalRz | LocalR | GlobalR | GlobalRz,
25+
) -> interp.StatementResult[RuntimeFrame]:
26+
"""Handle gate statements and mark the frame as quantum."""
27+
frame.is_quantum = True
28+
frame.quantum_stmts.add(stmt)
29+
return ()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from . import runtime as runtime
12
from ._dialect import dialect as dialect
23
from ._interface import fill as fill
34
from .stmts import Fill as Fill
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from kirin import interp
2+
3+
from bloqade.shuttle.analysis.runtime import (
4+
RuntimeAnalysis,
5+
RuntimeFrame,
6+
)
7+
8+
from ._dialect import dialect
9+
from .stmts import Fill
10+
11+
12+
@dialect.register(key="runtime")
13+
class HasQuantumRuntimeMethodTable(interp.MethodTable):
14+
15+
@interp.impl(Fill)
16+
def gate(self, interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: Fill):
17+
"""Handle gate statements and mark the frame as quantum."""
18+
frame.is_quantum = True
19+
frame.quantum_stmts.add(stmt)
20+
return ()

src/bloqade/shuttle/dialects/init/stmts.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,7 @@
1010
class Fill(ir.Statement):
1111
name = "fill"
1212

13-
traits = frozenset(
14-
{
15-
lowering.FromPythonCall(),
16-
}
17-
)
13+
traits = frozenset({lowering.FromPythonCall()})
1814
locations: ir.SSAValue = info.argument(
1915
ilist.IListType[grid.GridType[types.Any, types.Any], types.Any]
2016
)

src/bloqade/shuttle/dialects/measure/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from . import runtime as runtime
12
from ._dialect import dialect as dialect
23
from ._interface import measure as measure
34
from .stmts import Measure as Measure, New as New
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from kirin import interp
2+
3+
from bloqade.shuttle.analysis.runtime import (
4+
RuntimeAnalysis,
5+
RuntimeFrame,
6+
)
7+
8+
from ._dialect import dialect
9+
from .stmts import Measure
10+
11+
12+
@dialect.register(key="runtime")
13+
class HasQuantumRuntimeMethodTable(interp.MethodTable):
14+
15+
@interp.impl(Measure)
16+
def gate(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: Measure):
17+
"""Handle gate statements and mark the frame as quantum."""
18+
frame.is_quantum = True
19+
frame.quantum_stmts.add(stmt)
20+
return (_interp.lattice.top(),)

src/bloqade/shuttle/dialects/path/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from . import runtime as runtime
12
from ._dialect import dialect as dialect
23
from .concrete import PathInterpreter as PathInterpreter
34
from .constprop import ConstProp as ConstProp
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from kirin import interp
2+
3+
from bloqade.shuttle.analysis.runtime import (
4+
RuntimeAnalysis,
5+
RuntimeFrame,
6+
)
7+
8+
from ._dialect import dialect
9+
from .stmts import Play
10+
11+
12+
@dialect.register(key="runtime")
13+
class HasQuantumRuntimeMethodTable(interp.MethodTable):
14+
15+
@interp.impl(Play)
16+
def gate(
17+
self, interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: Play
18+
) -> interp.StatementResult[RuntimeFrame]:
19+
"""Handle gate statements and mark the frame as quantum."""
20+
frame.is_quantum = True
21+
frame.quantum_stmts.add(stmt)
22+
return ()

0 commit comments

Comments
 (0)