88from kirin .validation import ValidationPass
99from kirin .analysis .forward import ForwardFrame
1010
11- from bloqade .qasm2 .types import CRegType
11+ from bloqade .qasm2 .types import BitType , CRegType
1212from bloqade .qasm2 .dialects .core import CRegEq
1313from bloqade .qasm2 .passes .unroll_if import DontLiftType
1414
@@ -41,8 +41,7 @@ def if_else(
4141 stmt : scf .IfElse ,
4242 ):
4343
44- if interp_ .strict_if_conditions :
45- self .__validate_if_condition (interp_ , stmt )
44+ self .__validate_if_condition (interp_ , stmt , interp_ .strict_if_conditions )
4645
4746 if len (stmt .then_body .blocks ) > 1 :
4847 interp_ .add_validation_error (
@@ -92,7 +91,7 @@ def if_else(
9291 self .__validate_empty_yield (interp_ , else_stmts [- 1 ])
9392
9493 def __validate_if_condition (
95- self , interp_ : _QASM2ValidationAnalysis , stmt : scf .IfElse
94+ self , interp_ : _QASM2ValidationAnalysis , stmt : scf .IfElse , strict : bool
9695 ):
9796 cond = stmt .cond
9897 cond_owner = cond .owner
@@ -108,21 +107,57 @@ def __validate_if_condition(
108107 lhs = cond_owner .lhs
109108 rhs = cond_owner .rhs
110109
111- one_side_is_creg = lhs .type .is_subseteq (CRegType ) ^ rhs .type .is_subseteq (
112- CRegType
113- )
114- one_side_is_int = lhs .type .is_subseteq (types .Int ) ^ rhs .type .is_subseteq (
115- types .Int
116- )
110+ # Guard against bottom types
111+ lhs_is_bottom = lhs .type .is_subseteq (types .Bottom )
112+ rhs_is_bottom = rhs .type .is_subseteq (types .Bottom )
117113
118- if not ( one_side_is_int and one_side_is_creg ) :
114+ if rhs_is_bottom or lhs_is_bottom :
119115 interp_ .add_validation_error (
120116 stmt ,
121117 ir .ValidationError (
122- stmt ,
123- f"Native QASM2 syntax only allows comparing an entire classical register to an integer, but got { lhs } == { rhs } " ,
118+ stmt , f"Unexpected type in comparison: { lhs } == { rhs } "
124119 ),
125120 )
121+ return
122+
123+ if strict :
124+ # NOTE: only allow creg == int according to QASM2 spec
125+ one_side_is_creg = lhs .type .is_subseteq (CRegType ) ^ rhs .type .is_subseteq (
126+ CRegType
127+ )
128+ one_side_is_int = lhs .type .is_subseteq (types .Int ) ^ rhs .type .is_subseteq (
129+ types .Int
130+ )
131+
132+ if not (one_side_is_int and one_side_is_creg ):
133+ interp_ .add_validation_error (
134+ stmt ,
135+ ir .ValidationError (
136+ stmt ,
137+ f"Native QASM2 syntax only allows comparing an entire classical register to an integer, but got { lhs } == { rhs } " ,
138+ ),
139+ )
140+ else :
141+ # NOTE: more lenient syntax: also allow creg1 == creg2, creg[0] == x
142+ check_lhs = (
143+ lhs .type .is_subseteq (CRegType )
144+ or lhs .type .is_subseteq (BitType )
145+ or lhs .type .is_subseteq (types .Int )
146+ )
147+ check_rhs = (
148+ rhs .type .is_subseteq (CRegType )
149+ or rhs .type .is_subseteq (BitType )
150+ or rhs .type .is_subseteq (types .Int )
151+ )
152+
153+ if not check_lhs and check_rhs :
154+ interp_ .add_validation_error (
155+ stmt ,
156+ ir .ValidationError (
157+ stmt ,
158+ f"Expected classical register, bits or integers in if-statement, but got { lhs } == { rhs } " ,
159+ ),
160+ )
126161
127162 def __validate_empty_yield (
128163 self , interp_ : _QASM2ValidationAnalysis , stmt : ir .Statement
0 commit comments