Skip to content

Commit ef935d0

Browse files
committed
Validation for QASM2 main kernels
1 parent d29a1c3 commit ef935d0

File tree

8 files changed

+169
-27
lines changed

8 files changed

+169
-27
lines changed

src/bloqade/qasm2/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
emit as emit,
55
glob as glob,
66
parse as parse,
7+
analysis as analysis,
78
dialects as dialects,
89
parallel as parallel,
910
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .validation.analysis import QASM2Validation as QASM2Validation

src/bloqade/qasm2/analysis/validation/__init__.py

Whitespace-only changes.
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from typing import Any
2+
3+
from kirin import ir, interp
4+
from kirin.lattice import EmptyLattice
5+
from kirin.analysis import Forward
6+
from kirin.dialects import scf
7+
from kirin.validation import ValidationPass
8+
from kirin.analysis.forward import ForwardFrame
9+
10+
from bloqade.qasm2.passes.unroll_if import DontLiftType
11+
12+
13+
class _QASM2ValidationAnalysis(Forward[EmptyLattice]):
14+
keys = ["qasm2.main.validation"]
15+
16+
lattice = EmptyLattice
17+
18+
def method_self(self, method: ir.Method) -> EmptyLattice:
19+
return self.lattice.bottom()
20+
21+
def eval_fallback(
22+
self, frame: ForwardFrame[EmptyLattice], node: ir.Statement
23+
) -> tuple[EmptyLattice, ...]:
24+
return tuple(self.lattice.bottom() for _ in range(len(node.results)))
25+
26+
27+
@scf.dialect.register(key="qasm2.main.validation")
28+
class __ScfMethods(interp.MethodTable):
29+
30+
@interp.impl(scf.IfElse)
31+
def if_else(
32+
self,
33+
interp_: _QASM2ValidationAnalysis,
34+
frame: ForwardFrame[EmptyLattice],
35+
stmt: scf.IfElse,
36+
):
37+
38+
# TODO: stmt.condition has to be based off a measurement
39+
40+
if len(stmt.then_body.blocks) > 1:
41+
interp_.add_validation_error(
42+
stmt,
43+
ir.ValidationError(
44+
stmt,
45+
"Only single block is allowed in the then-body of an if-else statement!",
46+
),
47+
)
48+
49+
then_stmts = list(stmt.then_body.stmts())
50+
if len(then_stmts) > 2:
51+
interp_.add_validation_error(
52+
stmt,
53+
ir.ValidationError(
54+
stmt,
55+
"Only single statements are allowed inside the then-body of an if-else statement!",
56+
),
57+
)
58+
59+
if not isinstance(then_stmts[0], DontLiftType):
60+
interp_.add_validation_error(
61+
stmt,
62+
ir.ValidationError(
63+
stmt, f"Statement {then_stmts[0]} not allowed inside if clause!"
64+
),
65+
)
66+
67+
self.__validate_empty_yield(interp_, then_stmts[-1])
68+
69+
if len(stmt.else_body.blocks) > 1:
70+
interp_.add_validation_error(
71+
stmt,
72+
ir.ValidationError(
73+
stmt,
74+
"Only single block is allowed in the else-body of an if-else statement!",
75+
),
76+
)
77+
78+
else_stmts = list(stmt.else_body.stmts())
79+
if len(else_stmts) > 1:
80+
interp_.add_validation_error(
81+
stmt,
82+
ir.ValidationError(stmt, "Non-empty else is not allowed in QASM2!"),
83+
)
84+
85+
self.__validate_empty_yield(interp_, else_stmts[-1])
86+
87+
def __validate_empty_yield(
88+
self, interp_: _QASM2ValidationAnalysis, stmt: ir.Statement
89+
):
90+
if not isinstance(stmt, scf.Yield):
91+
interp_.add_validation_error(
92+
stmt,
93+
ir.ValidationError(
94+
stmt, f"Expected scf.Yield terminator in if clause, got {stmt}"
95+
),
96+
)
97+
elif len(stmt.values) > 0:
98+
interp_.add_validation_error(
99+
stmt, ir.ValidationError(stmt, "Cannot yield values from if statement!")
100+
)
101+
102+
@interp.impl(scf.For)
103+
def for_loop(
104+
self,
105+
interp_: _QASM2ValidationAnalysis,
106+
frame: ForwardFrame[EmptyLattice],
107+
stmt: scf.For,
108+
):
109+
interp_.add_validation_error(
110+
stmt, ir.ValidationError(stmt, "Loops not supported in QASM2!")
111+
)
112+
113+
114+
class QASM2Validation(ValidationPass):
115+
def name(self) -> str:
116+
return "QASM2 validation"
117+
118+
def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]:
119+
analysis = _QASM2ValidationAnalysis(method.dialects)
120+
frame, _ = analysis.run(method)
121+
return frame, analysis.get_validation_errors()

src/bloqade/qasm2/emit/target.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from rich.console import Console
55
from kirin.analysis import CallGraph
66
from kirin.dialects import ilist
7+
from kirin.validation import ValidationSuite
78

89
from bloqade.qasm2.parse import ast, pprint
910
from bloqade.qasm2.passes.fold import QASM2Fold
@@ -14,6 +15,7 @@
1415
from . import impls as impls # register the tables
1516
from .gate import EmitQASM2Gate
1617
from .main import EmitQASM2Main
18+
from ..analysis import QASM2Validation
1719

1820

1921
class QASM2:
@@ -114,6 +116,8 @@ def emit(self, entry: ir.Method) -> ast.MainProgram:
114116
# rewrite parallel to uop
115117
ParallelToUOp(dialects=entry.dialects)(entry)
116118

119+
ValidationSuite([QASM2Validation]).validate(entry).raise_if_invalid()
120+
117121
Py2QASM(entry.dialects)(entry)
118122
target_main = EmitQASM2Main(self.main_target).initialize()
119123
target_main.run(entry)

src/bloqade/qasm2/groups.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from kirin import ir, passes
22
from kirin.prelude import structural_no_opt
33
from kirin.dialects import scf, func, ilist, ssacfg, lowering
4+
from kirin.validation import ValidationSuite
45

6+
from bloqade.qasm2.passes import UnrollIfs
7+
from bloqade.qasm2.analysis import QASM2Validation
58
from bloqade.qasm2.dialects import (
69
uop,
710
core,
@@ -64,6 +67,7 @@ def run_pass(
6467
def main(self):
6568
fold_pass = passes.Fold(self)
6669
typeinfer_pass = passes.TypeInfer(self)
70+
unroll_ifs = UnrollIfs(self)
6771

6872
def run_pass(
6973
method: ir.Method,
@@ -78,6 +82,10 @@ def run_pass(
7882

7983
typeinfer_pass(method)
8084
method.verify_type()
85+
unroll_ifs(method)
86+
87+
validation_result = ValidationSuite([QASM2Validation]).validate(method)
88+
validation_result.raise_if_invalid()
8189

8290
return run_pass
8391

test/qasm2/emit/test_qasm2_emit.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import pytest
2-
from kirin.interp import InterpreterError
3-
41
from bloqade import qasm2
52

63

@@ -240,8 +237,13 @@ def non_empty_else():
240237

241238
target = qasm2.emit.QASM2()
242239

243-
with pytest.raises(InterpreterError):
240+
had_error = False
241+
try:
244242
target.emit(non_empty_else)
243+
except Exception:
244+
# TODO: this is just to work around ExceptionGroup for now
245+
had_error = True
246+
assert had_error
245247

246248
@qasm2.extended
247249
def multiline_then():
@@ -256,8 +258,13 @@ def multiline_then():
256258
return q
257259

258260
target = qasm2.emit.QASM2(unroll_ifs=False)
259-
with pytest.raises(InterpreterError):
261+
had_error = False
262+
try:
260263
target.emit(multiline_then)
264+
except Exception:
265+
# TODO: this is just to work around ExceptionGroup for now
266+
had_error = True
267+
assert had_error
261268

262269
@qasm2.extended
263270
def valid_if():

test/qasm2/test_count.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from bloqade import qasm2
55
from bloqade.analysis.address import (
6-
Unknown,
76
AddressReg,
87
ConstResult,
98
AddressQubit,
@@ -51,27 +50,28 @@ def tuple_count():
5150
assert isinstance(ret.data[1], AddressReg) and ret.data[1].data == range(3, 7)
5251

5352

54-
def test_dynamic_address():
55-
@qasm2.main
56-
def dynamic_address():
57-
ra = qasm2.qreg(3)
58-
rb = qasm2.qreg(4)
59-
ca = qasm2.creg(2)
60-
qasm2.measure(ra[0], ca[0])
61-
qasm2.measure(rb[1], ca[1])
62-
if ca[0] == ca[1]:
63-
ret = ra
64-
else:
65-
ret = rb
66-
67-
return ret
68-
69-
# dynamic_address.code.print()
70-
dynamic_address.print()
71-
fold(dynamic_address)
72-
frame, result = address.run(dynamic_address)
73-
dynamic_address.print(analysis=frame.entries)
74-
assert isinstance(result, Unknown)
53+
# NOTE: this is also invalid for QASM2 - you can't yield from if statements and no else bodies
54+
# def test_dynamic_address():
55+
# @qasm2.main
56+
# def dynamic_address():
57+
# ra = qasm2.qreg(3)
58+
# rb = qasm2.qreg(4)
59+
# ca = qasm2.creg(2)
60+
# qasm2.measure(ra[0], ca[0])
61+
# qasm2.measure(rb[1], ca[1])
62+
# if ca[0] == ca[1]:
63+
# ret = ra
64+
# else:
65+
# ret = rb
66+
67+
# return ret
68+
69+
# # dynamic_address.code.print()
70+
# dynamic_address.print()
71+
# fold(dynamic_address)
72+
# frame, result = address.run(dynamic_address)
73+
# dynamic_address.print(analysis=frame.entries)
74+
# assert isinstance(result, Unknown)
7575

7676

7777
# NOTE: this is invalid for QASM2

0 commit comments

Comments
 (0)