Skip to content

Commit 9ff1d99

Browse files
committed
Merge branch 'david/571-kirin-upgrade-branch' into david/574-upgrade-interpreters
2 parents 17e8ffb + 435b33b commit 9ff1d99

File tree

15 files changed

+518
-10
lines changed

15 files changed

+518
-10
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "bloqade-circuit"
3-
version = "0.8.0-DEV"
3+
version = "0.9.0-DEV"
44
description = "The software development toolkit for neutral atom arrays."
55
readme = "README.md"
66
authors = [

src/bloqade/gemini/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .groups import logical as logical
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .logical_validation.analysis import (
2+
GeminiLogicalValidationAnalysis as GeminiLogicalValidationAnalysis,
3+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import impls as impls, analysis as analysis # NOTE: register methods
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from kirin import ir
2+
3+
from bloqade import squin
4+
from bloqade.validation.analysis import ValidationFrame, ValidationAnalysis
5+
6+
7+
class GeminiLogicalValidationAnalysis(ValidationAnalysis):
8+
keys = ["gemini.validate.logical"]
9+
10+
first_gate = True
11+
12+
def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement):
13+
if isinstance(stmt, squin.gate.stmts.Gate):
14+
# NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here
15+
self.first_gate = False
16+
17+
return super().eval_stmt_fallback(frame, stmt)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from kirin import ir, interp as _interp
2+
from kirin.analysis import const
3+
from kirin.dialects import scf, func
4+
5+
from bloqade.squin import gate
6+
from bloqade.validation.analysis import ValidationFrame
7+
from bloqade.validation.analysis.lattice import Error
8+
9+
from .analysis import GeminiLogicalValidationAnalysis
10+
11+
12+
@scf.dialect.register(key="gemini.validate.logical")
13+
class __ScfGeminiLogicalValidation(_interp.MethodTable):
14+
15+
@_interp.impl(scf.IfElse)
16+
def if_else(
17+
self,
18+
interp: GeminiLogicalValidationAnalysis,
19+
frame: ValidationFrame,
20+
stmt: scf.IfElse,
21+
):
22+
frame.errors.append(
23+
ir.ValidationError(
24+
stmt, "If statements are not supported in logical Gemini programs!"
25+
)
26+
)
27+
return (
28+
Error(
29+
message="If statements are not supported in logical Gemini programs!"
30+
),
31+
)
32+
33+
@_interp.impl(scf.For)
34+
def for_loop(
35+
self,
36+
interp: GeminiLogicalValidationAnalysis,
37+
frame: ValidationFrame,
38+
stmt: scf.For,
39+
):
40+
if isinstance(stmt.iterable.hints.get("const"), const.Value):
41+
return (interp.lattice.top(),)
42+
43+
frame.errors.append(
44+
ir.ValidationError(
45+
stmt,
46+
"Non-constant iterable in for loop is not supported in Gemini logical programs!",
47+
)
48+
)
49+
50+
return (
51+
Error(
52+
message="Non-constant iterable in for loop is not supported in Gemini logical programs!"
53+
),
54+
)
55+
56+
57+
@func.dialect.register(key="gemini.validate.logical")
58+
class __FuncGeminiLogicalValidation(_interp.MethodTable):
59+
@_interp.impl(func.Invoke)
60+
def invoke(
61+
self,
62+
interp: GeminiLogicalValidationAnalysis,
63+
frame: ValidationFrame,
64+
stmt: func.Invoke,
65+
):
66+
frame.errors.append(
67+
ir.ValidationError(
68+
stmt,
69+
"Function invocations not supported in logical Gemini program!",
70+
help="Make sure to decorate your function with `@logical(inline = True)` or `@logical(aggressive_unroll = True)` to inline function calls",
71+
)
72+
)
73+
74+
return tuple(
75+
Error(
76+
message="Function invocations not supported in logical Gemini program!"
77+
)
78+
for _ in stmt.results
79+
)
80+
81+
82+
@gate.dialect.register(key="gemini.validate.logical")
83+
class __GateGeminiLogicalValidation(_interp.MethodTable):
84+
@_interp.impl(gate.stmts.U3)
85+
def u3(
86+
self,
87+
interp: GeminiLogicalValidationAnalysis,
88+
frame: ValidationFrame,
89+
stmt: gate.stmts.U3,
90+
):
91+
if interp.first_gate:
92+
interp.first_gate = False
93+
return ()
94+
95+
frame.errors.append(
96+
ir.ValidationError(
97+
stmt,
98+
"U3 gate can only be used for initial state preparation, i.e. as the first gate!",
99+
)
100+
)
101+
return ()

src/bloqade/gemini/groups.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from typing import Annotated
2+
3+
from kirin import ir
4+
from kirin.passes import Default
5+
from kirin.prelude import structural_no_opt
6+
from kirin.dialects import py, func, ilist
7+
from typing_extensions import Doc
8+
from kirin.passes.inline import InlinePass
9+
10+
from bloqade.squin import gate, qubit
11+
from bloqade.validation import KernelValidation
12+
from bloqade.rewrite.passes import AggressiveUnroll
13+
14+
from .analysis import GeminiLogicalValidationAnalysis
15+
16+
17+
@ir.dialect_group(structural_no_opt.union([gate, py.constant, qubit, func, ilist]))
18+
def logical(self):
19+
"""Compile a function to a Gemini logical kernel."""
20+
21+
def run_pass(
22+
mt,
23+
*,
24+
verify: Annotated[
25+
bool, Doc("run `verify` before running passes, default is `True`")
26+
] = True,
27+
typeinfer: Annotated[
28+
bool,
29+
Doc("run type inference and apply the inferred type to IR, default `True`"),
30+
] = True,
31+
fold: Annotated[bool, Doc("run folding passes")] = True,
32+
aggressive: Annotated[
33+
bool, Doc("run aggressive folding passes if `fold=True`")
34+
] = False,
35+
inline: Annotated[bool, Doc("inline function calls, default `True`")] = True,
36+
aggressive_unroll: Annotated[
37+
bool,
38+
Doc(
39+
"Run aggressive inlining and unrolling pass on the IR, default `False`"
40+
),
41+
] = False,
42+
no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True,
43+
) -> None:
44+
45+
if inline and not aggressive_unroll:
46+
InlinePass(mt.dialects, no_raise=no_raise).fixpoint(mt)
47+
48+
if aggressive_unroll:
49+
AggressiveUnroll(mt.dialects, no_raise=no_raise).fixpoint(mt)
50+
else:
51+
default_pass = Default(
52+
self,
53+
verify=verify,
54+
fold=fold,
55+
aggressive=aggressive,
56+
typeinfer=typeinfer,
57+
no_raise=no_raise,
58+
)
59+
60+
default_pass.fixpoint(mt)
61+
62+
if verify:
63+
validator = KernelValidation(GeminiLogicalValidationAnalysis)
64+
validator.run(mt, no_raise=no_raise)
65+
mt.verify()
66+
67+
return run_pass

src/bloqade/rewrite/passes/aggressive_unroll.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
InlineGetItem,
1515
InlineGetField,
1616
DeadCodeElimination,
17+
CommonSubexpressionElimination,
1718
)
1819
from kirin.dialects import scf, ilist
1920
from kirin.ir.method import Method
2021
from kirin.rewrite.abc import RewriteResult
21-
from kirin.rewrite.cse import CommonSubexpressionElimination
2222
from kirin.passes.aggressive import UnrollScf
2323

24+
from .canonicalize_ilist import CanonicalizeIList
25+
2426

2527
@dataclass
2628
class Fold(Pass):
@@ -55,30 +57,36 @@ class AggressiveUnroll(Pass):
5557
fold: Fold = field(init=False)
5658
typeinfer: TypeInfer = field(init=False)
5759
scf_unroll: UnrollScf = field(init=False)
60+
canonicalize_ilist: CanonicalizeIList = field(init=False)
5861

5962
def __post_init__(self):
6063
self.fold = Fold(self.dialects, no_raise=self.no_raise)
6164
self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)
6265
self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise)
66+
self.canonicalize_ilist = CanonicalizeIList(
67+
self.dialects, no_raise=self.no_raise
68+
)
6369

6470
def unsafe_run(self, mt: Method) -> RewriteResult:
6571
result = RewriteResult()
72+
result = self.fold.unsafe_run(mt).join(result)
6673
result = self.scf_unroll.unsafe_run(mt).join(result)
74+
self.typeinfer.unsafe_run(
75+
mt
76+
) # Do not join the result of typeinfer or fixpoint will waste time
6777
result = (
6878
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
6979
.rewrite(mt.code)
7080
.join(result)
7181
)
72-
self.typeinfer.unsafe_run(mt)
73-
result = self.fold.unsafe_run(mt).join(result)
7482
result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
7583
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
76-
84+
result = self.canonicalize_ilist.fixpoint(mt).join(result)
7785
rule = Chain(
7886
CommonSubexpressionElimination(),
7987
DeadCodeElimination(),
8088
)
81-
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
89+
result = Walk(rule).rewrite(mt.code).join(result)
8290

8391
return result
8492

src/bloqade/squin/gate/stmts.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88

99

1010
@statement
11-
class SingleQubitGate(ir.Statement):
11+
class Gate(ir.Statement):
12+
# NOTE: just for easier isinstance checks elsewhere, all gates inherit from this class
13+
pass
14+
15+
16+
@statement
17+
class SingleQubitGate(Gate):
1218
traits = frozenset({lowering.FromPythonCall()})
1319
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
1420

@@ -59,7 +65,7 @@ class SqrtY(SingleQubitNonHermitianGate):
5965

6066

6167
@statement
62-
class RotationGate(ir.Statement):
68+
class RotationGate(Gate):
6369
# NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg
6470
traits = frozenset({lowering.FromPythonCall()})
6571
angle: ir.SSAValue = info.argument(types.Float)
@@ -85,7 +91,7 @@ class Rz(RotationGate):
8591

8692

8793
@statement
88-
class ControlledGate(ir.Statement):
94+
class ControlledGate(Gate):
8995
traits = frozenset({lowering.FromPythonCall()})
9096
controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
9197
targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
@@ -110,7 +116,7 @@ class CZ(ControlledGate):
110116

111117

112118
@statement(dialect=dialect)
113-
class U3(ir.Statement):
119+
class U3(Gate):
114120
# NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg
115121
traits = frozenset({lowering.FromPythonCall()})
116122
theta: ir.SSAValue = info.argument(types.Float)

src/bloqade/validation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from . import analysis as analysis
2+
from .kernel_validation import KernelValidation as KernelValidation

0 commit comments

Comments
 (0)