|
6 | 6 | from kirin.ir.dialect import Dialect as Dialect |
7 | 7 | from typing_extensions import Self |
8 | 8 |
|
| 9 | +from bloqade.types import QubitType |
9 | 10 | from bloqade.qasm2.parse import ast |
10 | 11 | from bloqade.qasm2.dialects.uop import SingleQubitGate, TwoQubitCtrlGate |
11 | 12 | from bloqade.qasm2.dialects.expr import GateFunction |
@@ -105,27 +106,40 @@ def emit_func( |
105 | 106 | with gate_emitter.new_frame( |
106 | 107 | callable_node, has_parent_access=False |
107 | 108 | ) as gate_frame: |
108 | | - gate_result = gate_emitter.frame_eval(gate_frame, callable_node) |
109 | | - gate_obj = None |
110 | | - if isinstance(gate_result, tuple) and len(gate_result) > 0: |
111 | | - maybe = gate_result[0] |
112 | | - if isinstance(maybe, ast.Gate): |
113 | | - gate_obj = maybe |
114 | | - |
115 | | - if gate_obj is None: |
116 | | - name = emit.callables.get( |
117 | | - callable_node |
118 | | - ) or emit.callables.add(callable_node) |
119 | | - prefix = getattr(emit.callables, "prefix", "") or "" |
120 | | - emit_name = ( |
121 | | - name[len(prefix) :] |
122 | | - if prefix and name.startswith(prefix) |
123 | | - else name |
124 | | - ) |
125 | | - gate_obj = ast.Gate( |
126 | | - name=emit_name, cparams=[], qparams=[], body=[] |
| 109 | + args: list[ast.Node] = [] |
| 110 | + cparams, qparams = [], [] |
| 111 | + entry_args = callable_node.body.blocks[0].args |
| 112 | + user_args = entry_args[1:] if len(entry_args) > 0 else [] |
| 113 | + |
| 114 | + for arg in user_args: |
| 115 | + assert arg.name is not None |
| 116 | + |
| 117 | + args.append(ast.Name(id=arg.name)) |
| 118 | + if arg.type.is_subseteq(QubitType): |
| 119 | + qparams.append(arg.name) |
| 120 | + else: |
| 121 | + cparams.append(arg.name) |
| 122 | + |
| 123 | + # Map block arguments to AST names in the gate frame |
| 124 | + for arg in user_args: |
| 125 | + gate_frame.set( |
| 126 | + arg, ast.Name(id=getattr(arg, "name", "arg")) |
127 | 127 | ) |
128 | 128 |
|
| 129 | + # Actually emit the gate body by interpreting all blocks |
| 130 | + for block in callable_node.body.blocks: |
| 131 | + gate_emitter.emit_block(gate_frame, block) |
| 132 | + |
| 133 | + name = emit.callables.get(callable_node) or emit.callables.add( |
| 134 | + callable_node |
| 135 | + ) |
| 136 | + gate_obj = ast.Gate( |
| 137 | + name=name, |
| 138 | + cparams=cparams, |
| 139 | + qparams=qparams, |
| 140 | + body=gate_frame.body, |
| 141 | + ) |
| 142 | + |
129 | 143 | gate_defs.append(gate_obj) |
130 | 144 |
|
131 | 145 | if emit.dialects.data.intersection((parallel.dialect, glob.dialect)): |
|
0 commit comments