|
4 | 4 | from kirin import ir, types, lowering |
5 | 5 | from kirin.dialects import cf, func, ilist |
6 | 6 |
|
7 | | -from bloqade.qasm2.types import CRegType, QRegType |
| 7 | +from bloqade.qasm2.types import CRegType, QRegType, QubitType |
8 | 8 | from bloqade.qasm2.dialects import uop, core, expr, glob, noise, parallel |
9 | 9 |
|
10 | 10 | from . import ast |
@@ -101,6 +101,13 @@ def lower_literal(self, state: lowering.State[ast.Node], value) -> ir.SSAValue: |
101 | 101 | def lower_global( |
102 | 102 | self, state: lowering.State[ast.Node], node: ast.Node |
103 | 103 | ) -> lowering.LoweringABC.Result: |
| 104 | + if isinstance(node, ast.Name): |
| 105 | + # NOTE: might be a lookup for a gate function invoke |
| 106 | + try: |
| 107 | + return lowering.LoweringABC.Result(state.current_frame.globals[node.id]) |
| 108 | + except KeyError: |
| 109 | + pass |
| 110 | + |
104 | 111 | raise lowering.BuildError("Global variables are not supported in QASM 2.0") |
105 | 112 |
|
106 | 113 | def visit_MainProgram(self, state: lowering.State[ast.Node], node: ast.MainProgram): |
@@ -430,7 +437,56 @@ def visit_Include(self, state: lowering.State[ast.Node], node: ast.Include): |
430 | 437 | raise lowering.BuildError(f"Include {node.filename} not found") |
431 | 438 |
|
432 | 439 | def visit_Gate(self, state: lowering.State[ast.Node], node: ast.Gate): |
433 | | - raise NotImplementedError("Gate lowering not supported") |
| 440 | + arg_names = node.cparams + node.qparams |
| 441 | + arg_types = [types.Float for _ in node.cparams] + [ |
| 442 | + QubitType for _ in node.qparams |
| 443 | + ] |
| 444 | + |
| 445 | + with state.frame( |
| 446 | + stmts=node.body, |
| 447 | + finalize_next=False, |
| 448 | + ) as body_frame: |
| 449 | + # NOTE: insert _self as arg |
| 450 | + body_frame.curr_block.args.append_from( |
| 451 | + types.Generic( |
| 452 | + ir.Method, types.Tuple.where(tuple(arg_types)), types.NoneType |
| 453 | + ), |
| 454 | + name=node.name + "_self", |
| 455 | + ) |
| 456 | + |
| 457 | + for arg_type, arg_name in zip(arg_types, arg_names): |
| 458 | + # NOTE: append args as block arguments |
| 459 | + block_arg = body_frame.curr_block.args.append_from( |
| 460 | + arg_type, name=arg_name |
| 461 | + ) |
| 462 | + |
| 463 | + # NOTE: add arguments as definitions to frame |
| 464 | + body_frame.defs[arg_name] = block_arg |
| 465 | + |
| 466 | + body_frame.exhaust() |
| 467 | + |
| 468 | + # NOTE: append none as return value |
| 469 | + return_val = func.ConstantNone() |
| 470 | + body_frame.push(return_val) |
| 471 | + body_frame.push(func.Return(return_val)) |
| 472 | + |
| 473 | + body = body_frame.curr_region |
| 474 | + |
| 475 | + gate_func = expr.GateFunction( |
| 476 | + sym_name=node.name, |
| 477 | + signature=func.Signature(inputs=tuple(arg_types), output=types.NoneType), |
| 478 | + body=body, |
| 479 | + ) |
| 480 | + |
| 481 | + mt = ir.Method( |
| 482 | + mod=None, |
| 483 | + py_func=None, |
| 484 | + sym_name=node.name, |
| 485 | + dialects=self.dialects, |
| 486 | + arg_names=[*node.cparams, *node.qparams], |
| 487 | + code=gate_func, |
| 488 | + ) |
| 489 | + state.current_frame.globals[node.name] = mt |
434 | 490 |
|
435 | 491 | def visit_Instruction(self, state: lowering.State[ast.Node], node: ast.Instruction): |
436 | 492 | params = [state.lower(param).expect_one() for param in node.params] |
|
0 commit comments