|
2 | 2 | from dataclasses import field, dataclass |
3 | 3 |
|
4 | 4 | from kirin import ir, types, lowering |
5 | | -from kirin.dialects import cf, func, ilist |
| 5 | +from kirin.dialects import cf, scf, func, ilist |
6 | 6 |
|
7 | 7 | from bloqade.qasm2.types import CRegType, QRegType, QubitType |
8 | 8 | from bloqade.qasm2.dialects import uop, core, expr, glob, noise, parallel |
@@ -178,92 +178,30 @@ def visit_UGate(self, state: lowering.State[ast.Node], node: ast.UGate): |
178 | 178 | def visit_Reset(self, state: lowering.State[ast.Node], node: ast.Reset): |
179 | 179 | state.current_frame.push(core.Reset(qarg=state.lower(node.qarg).expect_one())) |
180 | 180 |
|
181 | | - # TODO: clean this up? copied from cf dialect with a small modification |
182 | 181 | def visit_IfStmt(self, state: lowering.State[ast.Node], node: ast.IfStmt): |
183 | 182 | cond_stmt = core.CRegEq( |
184 | 183 | lhs=state.lower(node.cond.lhs).expect_one(), |
185 | 184 | rhs=state.lower(node.cond.rhs).expect_one(), |
186 | 185 | ) |
187 | 186 | cond = state.current_frame.push(cond_stmt).result |
188 | 187 | frame = state.current_frame |
189 | | - before_block = frame.curr_block |
190 | 188 |
|
191 | | - with state.frame(node.body, region=frame.curr_region) as if_frame: |
| 189 | + with state.frame(node.body) as if_frame: |
192 | 190 | true_cond = if_frame.entr_block.args.append_from(types.Bool, cond.name) |
193 | 191 | if cond.name: |
194 | 192 | if_frame.defs[cond.name] = true_cond |
195 | 193 |
|
| 194 | + # NOTE: pass in definitions from outer scope (usually just for the qreg) |
| 195 | + if_frame.defs.update(frame.defs) |
| 196 | + |
196 | 197 | if_frame.exhaust() |
197 | | - self.branch_next_if_not_terminated(if_frame) |
198 | 198 |
|
199 | | - with state.frame([], region=frame.curr_region) as else_frame: |
200 | | - true_cond = else_frame.entr_block.args.append_from(types.Bool, cond.name) |
201 | | - if cond.name: |
202 | | - else_frame.defs[cond.name] = true_cond |
203 | | - else_frame.exhaust() |
204 | | - self.branch_next_if_not_terminated(else_frame) |
205 | | - |
206 | | - with state.frame(frame.stream.split(), region=frame.curr_region) as after_frame: |
207 | | - after_frame.defs.update(frame.defs) |
208 | | - phi: set[str] = set() |
209 | | - for name in if_frame.defs.keys(): |
210 | | - if frame.get(name): |
211 | | - phi.add(name) |
212 | | - elif name in else_frame.defs: |
213 | | - phi.add(name) |
214 | | - |
215 | | - for name in else_frame.defs.keys(): |
216 | | - if frame.get(name): # not defined in if_frame |
217 | | - phi.add(name) |
218 | | - |
219 | | - for name in phi: |
220 | | - after_frame.defs[name] = after_frame.entr_block.args.append_from( |
221 | | - types.Any, name |
222 | | - ) |
| 199 | + # NOTE: qasm2 can never yield anything from if |
| 200 | + if_frame.push(scf.Yield()) |
223 | 201 |
|
224 | | - after_frame.exhaust() |
225 | | - self.branch_next_if_not_terminated(after_frame) |
226 | | - after_frame.next_block.stmts.append( |
227 | | - cf.Branch(arguments=(), successor=frame.next_block) |
228 | | - ) |
| 202 | + then_body = if_frame.curr_region |
229 | 203 |
|
230 | | - if_args = [] |
231 | | - for name in phi: |
232 | | - if value := if_frame.get(name): |
233 | | - if_args.append(value) |
234 | | - else: |
235 | | - raise lowering.BuildError(f"undefined variable {name} in if branch") |
236 | | - |
237 | | - else_args = [] |
238 | | - for name in phi: |
239 | | - if value := else_frame.get(name): |
240 | | - else_args.append(value) |
241 | | - else: |
242 | | - raise lowering.BuildError(f"undefined variable {name} in else branch") |
243 | | - |
244 | | - if_frame.next_block.stmts.append( |
245 | | - cf.Branch( |
246 | | - arguments=tuple(if_args), |
247 | | - successor=after_frame.entr_block, |
248 | | - ) |
249 | | - ) |
250 | | - else_frame.next_block.stmts.append( |
251 | | - cf.Branch( |
252 | | - arguments=tuple(else_args), |
253 | | - successor=after_frame.entr_block, |
254 | | - ) |
255 | | - ) |
256 | | - before_block.stmts.append( |
257 | | - cf.ConditionalBranch( |
258 | | - cond=cond, |
259 | | - then_arguments=(cond,), |
260 | | - then_successor=if_frame.entr_block, |
261 | | - else_arguments=(cond,), |
262 | | - else_successor=else_frame.entr_block, |
263 | | - ) |
264 | | - ) |
265 | | - frame.defs.update(after_frame.defs) |
266 | | - frame.jump_next_block() |
| 204 | + state.current_frame.push(scf.IfElse(cond, then_body=then_body)) |
267 | 205 |
|
268 | 206 | def branch_next_if_not_terminated(self, frame: lowering.Frame): |
269 | 207 | """Branch to the next block if the current block is not terminated. |
|
0 commit comments