Skip to content

Commit c553a1c

Browse files
committed
Also validate the more lenient if syntax
1 parent 020f975 commit c553a1c

File tree

2 files changed

+61
-13
lines changed

2 files changed

+61
-13
lines changed

src/bloqade/qasm2/analysis/validation/analysis.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from kirin.validation import ValidationPass
99
from kirin.analysis.forward import ForwardFrame
1010

11-
from bloqade.qasm2.types import CRegType
11+
from bloqade.qasm2.types import BitType, CRegType
1212
from bloqade.qasm2.dialects.core import CRegEq
1313
from bloqade.qasm2.passes.unroll_if import DontLiftType
1414

@@ -41,8 +41,7 @@ def if_else(
4141
stmt: scf.IfElse,
4242
):
4343

44-
if interp_.strict_if_conditions:
45-
self.__validate_if_condition(interp_, stmt)
44+
self.__validate_if_condition(interp_, stmt, interp_.strict_if_conditions)
4645

4746
if len(stmt.then_body.blocks) > 1:
4847
interp_.add_validation_error(
@@ -92,7 +91,7 @@ def if_else(
9291
self.__validate_empty_yield(interp_, else_stmts[-1])
9392

9493
def __validate_if_condition(
95-
self, interp_: _QASM2ValidationAnalysis, stmt: scf.IfElse
94+
self, interp_: _QASM2ValidationAnalysis, stmt: scf.IfElse, strict: bool
9695
):
9796
cond = stmt.cond
9897
cond_owner = cond.owner
@@ -108,21 +107,57 @@ def __validate_if_condition(
108107
lhs = cond_owner.lhs
109108
rhs = cond_owner.rhs
110109

111-
one_side_is_creg = lhs.type.is_subseteq(CRegType) ^ rhs.type.is_subseteq(
112-
CRegType
113-
)
114-
one_side_is_int = lhs.type.is_subseteq(types.Int) ^ rhs.type.is_subseteq(
115-
types.Int
116-
)
110+
# Guard against bottom types
111+
lhs_is_bottom = lhs.type.is_subseteq(types.Bottom)
112+
rhs_is_bottom = rhs.type.is_subseteq(types.Bottom)
117113

118-
if not (one_side_is_int and one_side_is_creg):
114+
if rhs_is_bottom or lhs_is_bottom:
119115
interp_.add_validation_error(
120116
stmt,
121117
ir.ValidationError(
122-
stmt,
123-
f"Native QASM2 syntax only allows comparing an entire classical register to an integer, but got {lhs} == {rhs}",
118+
stmt, f"Unexpected type in comparison: {lhs} == {rhs}"
124119
),
125120
)
121+
return
122+
123+
if strict:
124+
# NOTE: only allow creg == int according to QASM2 spec
125+
one_side_is_creg = lhs.type.is_subseteq(CRegType) ^ rhs.type.is_subseteq(
126+
CRegType
127+
)
128+
one_side_is_int = lhs.type.is_subseteq(types.Int) ^ rhs.type.is_subseteq(
129+
types.Int
130+
)
131+
132+
if not (one_side_is_int and one_side_is_creg):
133+
interp_.add_validation_error(
134+
stmt,
135+
ir.ValidationError(
136+
stmt,
137+
f"Native QASM2 syntax only allows comparing an entire classical register to an integer, but got {lhs} == {rhs}",
138+
),
139+
)
140+
else:
141+
# NOTE: more lenient syntax: also allow creg1 == creg2, creg[0] == x
142+
check_lhs = (
143+
lhs.type.is_subseteq(CRegType)
144+
or lhs.type.is_subseteq(BitType)
145+
or lhs.type.is_subseteq(types.Int)
146+
)
147+
check_rhs = (
148+
rhs.type.is_subseteq(CRegType)
149+
or rhs.type.is_subseteq(BitType)
150+
or rhs.type.is_subseteq(types.Int)
151+
)
152+
153+
if not check_lhs and check_rhs:
154+
interp_.add_validation_error(
155+
stmt,
156+
ir.ValidationError(
157+
stmt,
158+
f"Expected classical register, bits or integers in if-statement, but got {lhs} == {rhs}",
159+
),
160+
)
126161

127162
def __validate_empty_yield(
128163
self, interp_: _QASM2ValidationAnalysis, stmt: ir.Statement

test/qasm2/test_ifs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,16 @@ def main_invalid():
111111

112112
with pytest.raises(ValidationErrorGroup):
113113
target.emit(main_invalid)
114+
115+
116+
def test_completely_invalid_if():
117+
with pytest.raises(ValidationErrorGroup):
118+
119+
@qasm2.main
120+
def main():
121+
q = qasm2.qreg(2)
122+
123+
if q[0] == 1:
124+
qasm2.x(q[1])
125+
126+
return q

0 commit comments

Comments
 (0)