|
2 | 2 | from warnings import warn |
3 | 3 | from dataclasses import field, dataclass |
4 | 4 |
|
| 5 | +from bloqade.rewrite.passes import AggressiveUnroll |
5 | 6 | import cirq |
6 | | -from kirin import ir, types, interp |
| 7 | +from kirin import ir, types, interp, passes |
7 | 8 | from kirin.emit import EmitABC, EmitError, EmitFrame |
8 | 9 | from kirin.interp import MethodTable, impl |
9 | 10 | from kirin.passes import inline |
10 | | -from kirin.dialects import func |
| 11 | +from kirin.dialects import func, py |
11 | 12 | from typing_extensions import Self |
12 | 13 |
|
13 | 14 | from bloqade.squin import kernel |
@@ -114,10 +115,49 @@ def main(): |
114 | 115 |
|
115 | 116 | emitter = EmitCirq(qubits=circuit_qubits) |
116 | 117 |
|
117 | | - mt_ = mt.similar(mt.dialects) |
118 | | - inline.InlinePass(mt_.dialects).fixpoint(mt_) |
119 | 118 |
|
120 | | - return emitter.run(mt_, args=args) |
| 119 | + |
| 120 | + symbol_op_trait = mt.code.get_trait(ir.SymbolOpInterface) |
| 121 | + if (symbol_op_trait := mt.code.get_trait(ir.SymbolOpInterface)) is None: |
| 122 | + raise EmitError( |
| 123 | + f"The method is not a symbol, cannot emit circuit!" |
| 124 | + ) |
| 125 | + |
| 126 | + sym_name = symbol_op_trait.get_sym_name(mt.code).unwrap() |
| 127 | + |
| 128 | + if (signature_trait := mt.code.get_trait(ir.HasSignature)) is None: |
| 129 | + raise EmitError( |
| 130 | + f"The method {sym_name} does not have a signature, cannot emit circuit!" |
| 131 | + ) |
| 132 | + |
| 133 | + signature = signature_trait.get_signature(mt.code) |
| 134 | + new_signature = func.Signature(inputs=(), output=signature.output) |
| 135 | + |
| 136 | + callable_region = mt.callable_region.clone() |
| 137 | + entry_block = callable_region.blocks[0] |
| 138 | + args_ssa = list(entry_block.args) |
| 139 | + first_stmt = entry_block.first_stmt |
| 140 | + |
| 141 | + assert first_stmt is not None, "Method has no statements!" |
| 142 | + if len(args_ssa) - 1 != len(args): |
| 143 | + raise EmitError( |
| 144 | + f"The method {sym_name} takes {len(args_ssa)} arguments, but you passed in {len(args)} via the `args` keyword!" |
| 145 | + ) |
| 146 | + |
| 147 | + for arg, arg_ssa in zip(args, args_ssa[1:], strict=True): |
| 148 | + (value := py.Constant(arg)).insert_before(first_stmt) |
| 149 | + arg_ssa.replace_by(value.result) |
| 150 | + entry_block.args.delete(arg_ssa) |
| 151 | + |
| 152 | + new_func = func.Function(sym_name=sym_name, body=callable_region, signature=new_signature) |
| 153 | + mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func) |
| 154 | + |
| 155 | + passes.Fold(mt_.dialects, no_raise=False)(mt_) |
| 156 | + mt_.print(hint="const") |
| 157 | + |
| 158 | + # AggressiveUnroll(mt_.dialects)(mt_) |
| 159 | + # mt_.print(hint="const") |
| 160 | + return emitter.run(mt_, args=()) |
121 | 161 |
|
122 | 162 |
|
123 | 163 | @dataclass |
|
0 commit comments