Skip to content

Commit b031dd6

Browse files
david-plweinbe58
andauthored
Implement gemini.logical and draft kernel validation (#558)
This implements the simple dialect group as discussed in #531. However, it also contains a draft (RFC) for validating kernels, using the above as an example. I thought a bit about it and in the end decided to split the validation into two parts: 1. A `ValidationAnalysis` pass with an appropriate lattice (could be extended to warnings, etc.) that checks for errors in the IR. 2. A simple `KernelValidation` class, that gathers up the errors and throws them as `kirin.ir.ValidationError` with the appropriate source info pointing to the line of the error. I went back and forth as to how to implement this, since it is a bit similar to interpretation. However, in the end I just decided to go with a very simple dataclass, since it is sufficiently different from an interpreter or a rewrite. One thing that's not great and I'm not sure how to improve: we probably don't want the `ValidationAnalysis` to implement a lot of methods via method tables. For example, the `gemini.logical` dialect doesn't support `scf.IfElse`. However, if we want to re-use the `ValidationAnalysis` elsewhere, then registering a `MethodTable` for `scf` with a method for `scf.IfElse` might fail validation in kernels where this should be fine. At the same time, we might want to share validation between different kernels. The way I "solved" this for now is to define a new analysis pass just for the logical dialect that inherits from the generic `ValidationAnalysis`. If we go this route, then the `ValidationAnalysis` basically just defines the lattice and a fallback method. There's also some TODOs left here, e.g. how to display multiple errors, but I thought I'd get comments as soon as possible. --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent acc29f3 commit b031dd6

File tree

13 files changed

+504
-4
lines changed

13 files changed

+504
-4
lines changed

src/bloqade/gemini/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .groups import logical as logical
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .logical_validation.analysis import (
2+
GeminiLogicalValidationAnalysis as GeminiLogicalValidationAnalysis,
3+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import impls as impls, analysis as analysis # NOTE: register methods
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from kirin import ir
2+
3+
from bloqade import squin
4+
from bloqade.validation.analysis import ValidationFrame, ValidationAnalysis
5+
6+
7+
class GeminiLogicalValidationAnalysis(ValidationAnalysis):
8+
keys = ["gemini.validate.logical"]
9+
10+
first_gate = True
11+
12+
def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement):
13+
if isinstance(stmt, squin.gate.stmts.Gate):
14+
# NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here
15+
self.first_gate = False
16+
17+
return super().eval_stmt_fallback(frame, stmt)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from kirin import ir, interp as _interp
2+
from kirin.analysis import const
3+
from kirin.dialects import scf, func
4+
5+
from bloqade.squin import gate
6+
from bloqade.validation.analysis import ValidationFrame
7+
from bloqade.validation.analysis.lattice import Error
8+
9+
from .analysis import GeminiLogicalValidationAnalysis
10+
11+
12+
@scf.dialect.register(key="gemini.validate.logical")
13+
class __ScfGeminiLogicalValidation(_interp.MethodTable):
14+
15+
@_interp.impl(scf.IfElse)
16+
def if_else(
17+
self,
18+
interp: GeminiLogicalValidationAnalysis,
19+
frame: ValidationFrame,
20+
stmt: scf.IfElse,
21+
):
22+
frame.errors.append(
23+
ir.ValidationError(
24+
stmt, "If statements are not supported in logical Gemini programs!"
25+
)
26+
)
27+
return (
28+
Error(
29+
message="If statements are not supported in logical Gemini programs!"
30+
),
31+
)
32+
33+
@_interp.impl(scf.For)
34+
def for_loop(
35+
self,
36+
interp: GeminiLogicalValidationAnalysis,
37+
frame: ValidationFrame,
38+
stmt: scf.For,
39+
):
40+
if isinstance(stmt.iterable.hints.get("const"), const.Value):
41+
return (interp.lattice.top(),)
42+
43+
frame.errors.append(
44+
ir.ValidationError(
45+
stmt,
46+
"Non-constant iterable in for loop is not supported in Gemini logical programs!",
47+
)
48+
)
49+
50+
return (
51+
Error(
52+
message="Non-constant iterable in for loop is not supported in Gemini logical programs!"
53+
),
54+
)
55+
56+
57+
@func.dialect.register(key="gemini.validate.logical")
58+
class __FuncGeminiLogicalValidation(_interp.MethodTable):
59+
@_interp.impl(func.Invoke)
60+
def invoke(
61+
self,
62+
interp: GeminiLogicalValidationAnalysis,
63+
frame: ValidationFrame,
64+
stmt: func.Invoke,
65+
):
66+
frame.errors.append(
67+
ir.ValidationError(
68+
stmt,
69+
"Function invocations not supported in logical Gemini program!",
70+
help="Make sure to decorate your function with `@logical(inline = True)` or `@logical(aggressive_unroll = True)` to inline function calls",
71+
)
72+
)
73+
74+
return tuple(
75+
Error(
76+
message="Function invocations not supported in logical Gemini program!"
77+
)
78+
for _ in stmt.results
79+
)
80+
81+
82+
@gate.dialect.register(key="gemini.validate.logical")
83+
class __GateGeminiLogicalValidation(_interp.MethodTable):
84+
@_interp.impl(gate.stmts.U3)
85+
def u3(
86+
self,
87+
interp: GeminiLogicalValidationAnalysis,
88+
frame: ValidationFrame,
89+
stmt: gate.stmts.U3,
90+
):
91+
if interp.first_gate:
92+
interp.first_gate = False
93+
return ()
94+
95+
frame.errors.append(
96+
ir.ValidationError(
97+
stmt,
98+
"U3 gate can only be used for initial state preparation, i.e. as the first gate!",
99+
)
100+
)
101+
return ()

src/bloqade/gemini/groups.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from typing import Annotated
2+
3+
from kirin import ir
4+
from kirin.passes import Default
5+
from kirin.prelude import structural_no_opt
6+
from kirin.dialects import py, func, ilist
7+
from typing_extensions import Doc
8+
from kirin.passes.inline import InlinePass
9+
10+
from bloqade.squin import gate, qubit
11+
from bloqade.validation import KernelValidation
12+
from bloqade.rewrite.passes import AggressiveUnroll
13+
14+
from .analysis import GeminiLogicalValidationAnalysis
15+
16+
17+
@ir.dialect_group(structural_no_opt.union([gate, py.constant, qubit, func, ilist]))
18+
def logical(self):
19+
"""Compile a function to a Gemini logical kernel."""
20+
21+
def run_pass(
22+
mt,
23+
*,
24+
verify: Annotated[
25+
bool, Doc("run `verify` before running passes, default is `True`")
26+
] = True,
27+
typeinfer: Annotated[
28+
bool,
29+
Doc("run type inference and apply the inferred type to IR, default `True`"),
30+
] = True,
31+
fold: Annotated[bool, Doc("run folding passes")] = True,
32+
aggressive: Annotated[
33+
bool, Doc("run aggressive folding passes if `fold=True`")
34+
] = False,
35+
inline: Annotated[bool, Doc("inline function calls, default `True`")] = True,
36+
aggressive_unroll: Annotated[
37+
bool,
38+
Doc(
39+
"Run aggressive inlining and unrolling pass on the IR, default `False`"
40+
),
41+
] = False,
42+
no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True,
43+
) -> None:
44+
45+
if inline and not aggressive_unroll:
46+
InlinePass(mt.dialects, no_raise=no_raise).fixpoint(mt)
47+
48+
if aggressive_unroll:
49+
AggressiveUnroll(mt.dialects, no_raise=no_raise).fixpoint(mt)
50+
else:
51+
default_pass = Default(
52+
self,
53+
verify=verify,
54+
fold=fold,
55+
aggressive=aggressive,
56+
typeinfer=typeinfer,
57+
no_raise=no_raise,
58+
)
59+
60+
default_pass.fixpoint(mt)
61+
62+
if verify:
63+
validator = KernelValidation(GeminiLogicalValidationAnalysis)
64+
validator.run(mt, no_raise=no_raise)
65+
mt.verify()
66+
67+
return run_pass

src/bloqade/squin/gate/stmts.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88

99

1010
@statement
11-
class SingleQubitGate(ir.Statement):
11+
class Gate(ir.Statement):
12+
# NOTE: just for easier isinstance checks elsewhere, all gates inherit from this class
13+
pass
14+
15+
16+
@statement
17+
class SingleQubitGate(Gate):
1218
traits = frozenset({lowering.FromPythonCall()})
1319
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
1420

@@ -59,7 +65,7 @@ class SqrtY(SingleQubitNonHermitianGate):
5965

6066

6167
@statement
62-
class RotationGate(ir.Statement):
68+
class RotationGate(Gate):
6369
# NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg
6470
traits = frozenset({lowering.FromPythonCall()})
6571
angle: ir.SSAValue = info.argument(types.Float)
@@ -85,7 +91,7 @@ class Rz(RotationGate):
8591

8692

8793
@statement
88-
class ControlledGate(ir.Statement):
94+
class ControlledGate(Gate):
8995
traits = frozenset({lowering.FromPythonCall()})
9096
controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
9197
targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
@@ -110,7 +116,7 @@ class CZ(ControlledGate):
110116

111117

112118
@statement(dialect=dialect)
113-
class U3(ir.Statement):
119+
class U3(Gate):
114120
# NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg
115121
traits = frozenset({lowering.FromPythonCall()})
116122
theta: ir.SSAValue = info.argument(types.Float)

src/bloqade/validation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from . import analysis as analysis
2+
from .kernel_validation import KernelValidation as KernelValidation
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from . import lattice as lattice
2+
from .analysis import (
3+
ValidationFrame as ValidationFrame,
4+
ValidationAnalysis as ValidationAnalysis,
5+
)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from abc import ABC
2+
from dataclasses import field, dataclass
3+
4+
from kirin import ir
5+
from kirin.analysis import ForwardExtra, ForwardFrame
6+
7+
from .lattice import ErrorType
8+
9+
10+
@dataclass
11+
class ValidationFrame(ForwardFrame[ErrorType]):
12+
# NOTE: cannot be set[Error] since that's not hashable
13+
errors: list[ir.ValidationError] = field(default_factory=list)
14+
"""List of all ecnountered errors.
15+
16+
Append a `kirin.ir.ValidationError` to this list in the method implementation
17+
in order for it to get picked up by the `KernelValidation` run.
18+
"""
19+
20+
21+
@dataclass
22+
class ValidationAnalysis(ForwardExtra[ValidationFrame, ErrorType], ABC):
23+
"""Analysis pass that indicates errors in the IR according to the respective method tables.
24+
25+
If you need to implement validation for a dialect shared by many groups (for example, if you need to ascertain if statements have a specific form)
26+
you'll need to inherit from this class.
27+
"""
28+
29+
lattice = ErrorType
30+
31+
def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]):
32+
return self.run_callable(method.code, (self.lattice.top(),) + args)
33+
34+
def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement):
35+
# NOTE: default to no errors
36+
return tuple(self.lattice.top() for _ in stmt.results)
37+
38+
def initialize_frame(
39+
self, code: ir.Statement, *, has_parent_access: bool = False
40+
) -> ValidationFrame:
41+
return ValidationFrame(code, has_parent_access=has_parent_access)

0 commit comments

Comments
 (0)