|
5 | 5 | import cirq |
6 | 6 | from kirin import ir, lowering |
7 | 7 | from kirin.rewrite import Walk, CFGCompactify |
8 | | -from kirin.dialects import py, ilist |
| 8 | +from kirin.dialects import py, scf, ilist |
9 | 9 |
|
10 | 10 | from .. import op, noise, qubit |
11 | 11 |
|
@@ -150,10 +150,79 @@ def lower_measurement( |
150 | 150 | ): |
151 | 151 | if len(node.qubits) == 1: |
152 | 152 | qbit = self.lower_qubit_getindex(state, node.qubits[0]) |
153 | | - return state.current_frame.push(qubit.MeasureQubit(qbit)) |
| 153 | + stmt = state.current_frame.push(qubit.MeasureQubit(qbit)) |
| 154 | + else: |
| 155 | + qbits = self.lower_qubit_getindices(state, node.qubits) |
| 156 | + stmt = state.current_frame.push(qubit.MeasureQubitList(qbits)) |
154 | 157 |
|
155 | | - qbits = self.lower_qubit_getindices(state, node.qubits) |
156 | | - return state.current_frame.push(qubit.MeasureQubitList(qbits)) |
| 158 | + key = node.gate.key |
| 159 | + if isinstance(key, cirq.MeasurementKey): |
| 160 | + key = key.name |
| 161 | + |
| 162 | + state.current_frame.defs[key] = stmt.result |
| 163 | + return stmt |
| 164 | + |
| 165 | + def visit_ClassicallyControlledOperation( |
| 166 | + self, state: lowering.State[CirqNode], node: cirq.ClassicallyControlledOperation |
| 167 | + ): |
| 168 | + conditions: list[ir.SSAValue] = [] |
| 169 | + for outcome in node.classical_controls: |
| 170 | + key = outcome.key |
| 171 | + if isinstance(key, cirq.MeasurementKey): |
| 172 | + key = key.name |
| 173 | + measurement_outcome = state.current_frame.defs[key] |
| 174 | + |
| 175 | + if measurement_outcome.type.is_subseteq(ilist.IListType): |
| 176 | + # NOTE: there is currently no convenient ilist.any method, so we need to use foldl |
| 177 | + # with a simple function that just does an or |
| 178 | + |
| 179 | + def bool_op_or(x: bool, y: bool) -> bool: |
| 180 | + return x or y |
| 181 | + |
| 182 | + f_code = state.current_frame.push( |
| 183 | + lowering.Python(self.dialects).python_function(bool_op_or) |
| 184 | + ) |
| 185 | + fn = ir.Method( |
| 186 | + mod=None, |
| 187 | + py_func=bool_op_or, |
| 188 | + sym_name="bool_op_or", |
| 189 | + arg_names=[], |
| 190 | + dialects=self.dialects, |
| 191 | + code=f_code, |
| 192 | + ) |
| 193 | + f_const = state.current_frame.push(py.constant.Constant(fn)) |
| 194 | + init_val = state.current_frame.push(py.Constant(False)).result |
| 195 | + condition = state.current_frame.push( |
| 196 | + ilist.Foldl(f_const.result, measurement_outcome, init=init_val) |
| 197 | + ).result |
| 198 | + else: |
| 199 | + condition = measurement_outcome |
| 200 | + |
| 201 | + conditions.append(condition) |
| 202 | + |
| 203 | + if len(conditions) == 1: |
| 204 | + condition = conditions[0] |
| 205 | + else: |
| 206 | + condition = state.current_frame.push( |
| 207 | + py.boolop.And(conditions[0], conditions[1]) |
| 208 | + ).result |
| 209 | + for next_cond in conditions[2:]: |
| 210 | + condition = state.current_frame.push( |
| 211 | + py.boolop.And(condition, next_cond) |
| 212 | + ).result |
| 213 | + |
| 214 | + then_stmt = self.visit(state, node.without_classical_controls()) |
| 215 | + |
| 216 | + assert isinstance( |
| 217 | + then_stmt, ir.Statement |
| 218 | + ), f"Expected operation of classically controlled node {node} to be lowered to a statement, got type {type(then_stmt)}. \ |
| 219 | + Please report this issue!" |
| 220 | + |
| 221 | + # NOTE: remove stmt from parent block |
| 222 | + then_stmt.detach() |
| 223 | + then_body = ir.Block((then_stmt,)) |
| 224 | + |
| 225 | + return state.current_frame.push(scf.IfElse(condition, then_body=then_body)) |
157 | 226 |
|
158 | 227 | def visit_SingleQubitPauliStringGateOperation( |
159 | 228 | self, |
|
0 commit comments