diff --git a/src/bloqade/analysis/address/analysis.py b/src/bloqade/analysis/address/analysis.py index 1f78cfdb0..921f4b3b5 100644 --- a/src/bloqade/analysis/address/analysis.py +++ b/src/bloqade/analysis/address/analysis.py @@ -15,7 +15,7 @@ class AddressAnalysis(Forward[Address]): This analysis pass can be used to track the global addresses of qubits and wires. """ - keys = ["qubit.address"] + keys = ("qubit.address",) _const_prop: const.Propagate lattice = Address next_address: int = field(init=False) @@ -45,7 +45,7 @@ def try_eval_const_prop( ) -> interp.StatementResult[Address]: _frame = self._const_prop.initialize_frame(frame.code) _frame.set_values(stmt.args, tuple(x.result for x in args)) - result = self._const_prop.eval_stmt(_frame, stmt) + result = self._const_prop.frame_eval(_frame, stmt) match result: case interp.ReturnValue(constant_ret): @@ -96,7 +96,8 @@ def run_lattice( self, callee: Address, inputs: tuple[Address, ...], - kwargs: tuple[str, ...], + keys: tuple[str, ...], + kwargs: tuple[Address, ...], ) -> Address: """Run a callable lattice element with the given inputs and keyword arguments. @@ -111,15 +112,16 @@ def run_lattice( """ match callee: - case PartialLambda(code=code, argnames=argnames): - _, ret = self.run_callable( - code, (callee,) + self.permute_values(argnames, inputs, kwargs) + case PartialLambda(code=code): + _, ret = self.call( + code, callee, *inputs, **{k: v for k, v in zip(keys, kwargs)} ) - return ret case ConstResult(const.Value(ir.Method() as method)): - _, ret = self.run_method( - method, - self.permute_values(method.arg_names, inputs, kwargs), + _, ret = self.call( + method.code, + self.method_self(method), + *inputs, + **{k: v for k, v in zip(keys, kwargs)}, ) return ret case _: @@ -137,14 +139,12 @@ def get_const_value(self, addr: Address, typ: Type[T]) -> T | None: return value - def eval_stmt_fallback(self, frame: ForwardFrame[Address], stmt: ir.Statement): - args = frame.get_values(stmt.args) + def eval_fallback(self, frame: ForwardFrame[Address], node: ir.Statement): + args = frame.get_values(node.args) if types.is_tuple_of(args, ConstResult): - return self.try_eval_const_prop(frame, stmt, args) + return self.try_eval_const_prop(frame, node, args) - return tuple(Address.from_type(result.type) for result in stmt.results) + return tuple(Address.from_type(result.type) for result in node.results) - def run_method(self, method: ir.Method, args: tuple[Address, ...]): - # NOTE: we do not support dynamic calls here, thus no need to propagate method object - self_mt = ConstResult(const.Value(method)) - return self.run_callable(method.code, (self_mt,) + args) + def method_self(self, method: ir.Method) -> Address: + return ConstResult(const.Value(method)) diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index 1a89bb3e2..d7932986b 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -97,7 +97,7 @@ def map_( results = [] for ele in values: - ret = interp_.run_lattice(fn, (ele,), ()) + ret = interp_.run_lattice(fn, (ele,), (), ()) results.append(ret) if isinstance(stmt, ilist.Map): @@ -180,13 +180,10 @@ def invoke( frame: ForwardFrame[Address], stmt: func.Invoke, ): - - args = interp_.permute_values( - stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs - ) - _, ret = interp_.run_method( - stmt.callee, - args, + _, ret = interp_.call( + stmt.callee.code, + interp_.method_self(stmt.callee), + *frame.get_values(stmt.inputs), ) return (ret,) @@ -219,7 +216,8 @@ def call( result = interp_.run_lattice( frame.get(stmt.callee), frame.get_values(stmt.inputs), - stmt.kwargs, + stmt.keys, + frame.get_values(stmt.kwargs), ) return (result,) @@ -319,26 +317,28 @@ def ifelse( ): body = stmt.then_body if const_cond.data else stmt.else_body with interp_.new_frame(stmt, has_parent_access=True) as body_frame: - ret = interp_.run_ssacfg_region(body_frame, body, (address_cond,)) + ret = interp_.frame_call_region(body_frame, stmt, body, address_cond) # interp_.set_values(frame, body_frame.entries.keys(), body_frame.entries.values()) return ret else: # run both branches with interp_.new_frame(stmt, has_parent_access=True) as then_frame: - then_results = interp_.run_ssacfg_region( - then_frame, stmt.then_body, (address_cond,) - ) - interp_.set_values( - frame, then_frame.entries.keys(), then_frame.entries.values() + then_results = interp_.frame_call_region( + then_frame, + stmt, + stmt.then_body, + address_cond, ) + frame.set_values(then_frame.entries.keys(), then_frame.entries.values()) with interp_.new_frame(stmt, has_parent_access=True) as else_frame: - else_results = interp_.run_ssacfg_region( - else_frame, stmt.else_body, (address_cond,) - ) - interp_.set_values( - frame, else_frame.entries.keys(), else_frame.entries.values() + else_results = interp_.frame_call_region( + else_frame, + stmt, + stmt.else_body, + address_cond, ) + frame.set_values(else_frame.entries.keys(), else_frame.entries.values()) # TODO: pick the non-return value if isinstance(then_results, interp.ReturnValue) and isinstance( else_results, interp.ReturnValue @@ -364,12 +364,12 @@ def for_loop( iter_type, iterable = interp_.unpack_iterable(frame.get(stmt.iterable)) if iter_type is None: - return interp_.eval_stmt_fallback(frame, stmt) + return interp_.eval_fallback(frame, stmt) for value in iterable: with interp_.new_frame(stmt, has_parent_access=True) as body_frame: - loop_vars = interp_.run_ssacfg_region( - body_frame, stmt.body, (value,) + loop_vars + loop_vars = interp_.frame_call_region( + body_frame, stmt, stmt.body, value, *loop_vars ) if loop_vars is None: diff --git a/src/bloqade/analysis/fidelity/analysis.py b/src/bloqade/analysis/fidelity/analysis.py index f1ad252f3..815b57256 100644 --- a/src/bloqade/analysis/fidelity/analysis.py +++ b/src/bloqade/analysis/fidelity/analysis.py @@ -4,7 +4,6 @@ from kirin import ir from kirin.lattice import EmptyLattice from kirin.analysis import Forward -from kirin.interp.value import Successor from kirin.analysis.forward import ForwardFrame from ..address import Address, AddressAnalysis @@ -48,15 +47,11 @@ def main(): The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered. """ - _current_gate_fidelity: float = field(init=False) - atom_survival_probability: list[float] = field(init=False) """ The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register. """ - _current_atom_survival_probability: list[float] = field(init=False) - addr_frame: ForwardFrame[Address] = field(init=False) def initialize(self): @@ -67,25 +62,15 @@ def initialize(self): ] return self - def posthook_succ(self, frame: ForwardFrame, succ: Successor): - self.gate_fidelity *= self._current_gate_fidelity - for i, _current_survival in enumerate(self._current_atom_survival_probability): - self.atom_survival_probability[i] *= _current_survival - - def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): + def eval_fallback(self, frame: ForwardFrame, node: ir.Statement): # NOTE: default is to conserve fidelity, so do nothing here return - def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]): - return self.run_callable(method.code, (self.lattice.bottom(),) + args) - - def run_analysis( - self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True - ) -> tuple[ForwardFrame, Any]: - self._run_address_analysis(method, no_raise=no_raise) - return super().run(method) + def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]: + self._run_address_analysis(method) + return super().run(method, *args, **kwargs) - def _run_address_analysis(self, method: ir.Method, no_raise: bool): + def _run_address_analysis(self, method: ir.Method): addr_analysis = AddressAnalysis(self.dialects) addr_frame, _ = addr_analysis.run(method=method) self.addr_frame = addr_frame diff --git a/src/bloqade/analysis/measure_id/analysis.py b/src/bloqade/analysis/measure_id/analysis.py index f2d5e9f3f..8b65b2f38 100644 --- a/src/bloqade/analysis/measure_id/analysis.py +++ b/src/bloqade/analysis/measure_id/analysis.py @@ -22,20 +22,16 @@ class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]): measure_count = 0 def initialize_frame( - self, code: ir.Statement, *, has_parent_access: bool = False + self, node: ir.Statement, *, has_parent_access: bool = False ) -> MeasureIDFrame: - return MeasureIDFrame(code, has_parent_access=has_parent_access) + return MeasureIDFrame(node, has_parent_access=has_parent_access) # Still default to bottom, # but let constants return the softer "NoMeasureId" type from impl - def eval_stmt_fallback( - self, frame: ForwardFrame[MeasureId], stmt: ir.Statement + def eval_fallback( + self, frame: ForwardFrame[MeasureId], node: ir.Statement ) -> tuple[MeasureId, ...]: - return tuple(NotMeasureId() for _ in stmt.results) - - def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]): - # NOTE: we do not support dynamic calls here, thus no need to propagate method object - return self.run_callable(method.code, (self.lattice.bottom(),) + args) + return tuple(NotMeasureId() for _ in node.results) # Xiu-zhe (Roger) Luo came up with this in the address analysis, # reused here for convenience (now modified to be a bit more graceful) @@ -45,7 +41,7 @@ def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]): T = TypeVar("T") def get_const_value( - self, input_type: type[T], value: ir.SSAValue + self, input_type: type[T] | tuple[type[T], ...], value: ir.SSAValue ) -> type[T] | None: if isinstance(hint := value.hints.get("const"), const.Value): data = hint.data diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 993b97bd3..439ae2a6d 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -138,11 +138,10 @@ def return_(self, _: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Retu def invoke( self, interp_: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Invoke ): - _, ret = interp_.run_method( - stmt.callee, - interp_.permute_values( - stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs - ), + _, ret = interp_.call( + stmt.callee.code, + interp_.method_self(stmt.callee), + *frame.get_values(stmt.inputs), ) return (ret,) diff --git a/src/bloqade/cirq_utils/emit/base.py b/src/bloqade/cirq_utils/emit/base.py index 199e49dc8..4c831d81f 100644 --- a/src/bloqade/cirq_utils/emit/base.py +++ b/src/bloqade/cirq_utils/emit/base.py @@ -189,39 +189,8 @@ def initialize_frame( node, has_parent_access=has_parent_access, qubits=self.qubits ) - def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]): - return self.call(method, *args) - - def run_callable_region( - self, - frame: EmitCirqFrame, - code: ir.Statement, - region: ir.Region, - args: tuple, - ): - if len(region.blocks) > 0: - block_args = list(region.blocks[0].args) - # NOTE: skip self arg - frame.set_values(block_args[1:], args) - - results = self.frame_eval(frame, code) - if isinstance(results, tuple): - if len(results) == 0: - return self.void - elif len(results) == 1: - return results[0] - raise interp.InterpreterError(f"Unexpected results {results}") - - def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit: - for stmt in block.stmts: - result = self.frame_eval(frame, stmt) - if isinstance(result, tuple): - frame.set_values(stmt.results, result) - - return self.circuit - def reset(self): - pass + self.circuit = cirq.Circuit() def eval_fallback(self, frame: EmitCirqFrame, node: ir.Statement) -> tuple: return tuple(None for _ in range(len(node.results))) diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index 7d81b45bc..fc6d895ab 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -99,7 +99,7 @@ def main(): ``` """ - target = Squin(dialects=dialects, circuit=circuit) + target = Squin(dialects, circuit) body = target.run( circuit, source=str(circuit), # TODO: proper source string @@ -144,8 +144,6 @@ def main(): ) mt = ir.Method( - mod=None, - py_func=None, sym_name=kernel_name, arg_names=arg_names, dialects=dialects, diff --git a/src/bloqade/gemini/analysis/logical_validation/analysis.py b/src/bloqade/gemini/analysis/logical_validation/analysis.py index 14a03cbf8..4cf2d7a23 100644 --- a/src/bloqade/gemini/analysis/logical_validation/analysis.py +++ b/src/bloqade/gemini/analysis/logical_validation/analysis.py @@ -9,9 +9,9 @@ class GeminiLogicalValidationAnalysis(ValidationAnalysis): first_gate = True - def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): - if isinstance(stmt, squin.gate.stmts.Gate): + def eval_fallback(self, frame: ValidationFrame, node: ir.Statement): + if isinstance(node, squin.gate.stmts.Gate): # NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here self.first_gate = False - return super().eval_stmt_fallback(frame, stmt) + return super().eval_fallback(frame, node) diff --git a/src/bloqade/native/upstream/squin2native.py b/src/bloqade/native/upstream/squin2native.py index 2a9131f28..998d34b68 100644 --- a/src/bloqade/native/upstream/squin2native.py +++ b/src/bloqade/native/upstream/squin2native.py @@ -62,16 +62,18 @@ def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method: all_dialects = chain.from_iterable( ker.dialects.data for kers in old_callgraph.defs.values() for ker in kers ) - new_dialects = ( - mt.dialects.union(all_dialects).discard(gate_dialect).union(kernel) - ) + combined_dialects = mt.dialects.union(all_dialects).union(kernel) - out = mt.similar(new_dialects) - UpdateDialectsOnCallGraph(new_dialects, no_raise=no_raise)(out) - CallGraphPass(new_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)(out) - # verify all kernels in the callgraph + out = mt.similar(combined_dialects) + UpdateDialectsOnCallGraph(combined_dialects, no_raise=no_raise)(out) + CallGraphPass(combined_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)( + out + ) + # verify all kernels in the callgraph and discard gate dialect + out.dialects.discard(gate_dialect) new_callgraph = CallGraph(out) for ker in new_callgraph.edges.keys(): + ker.dialects.discard(gate_dialect) ker.verify() return out diff --git a/src/bloqade/pyqrack/target.py b/src/bloqade/pyqrack/target.py index e9f212339..54700933a 100644 --- a/src/bloqade/pyqrack/target.py +++ b/src/bloqade/pyqrack/target.py @@ -87,7 +87,8 @@ def run( """ fold = Fold(mt.dialects) fold(mt) - return self._get_interp(mt).run(mt, args, kwargs) + _, ret = self._get_interp(mt).run(mt, *args, **kwargs) + return ret def multi_run( self, diff --git a/src/bloqade/pyqrack/task.py b/src/bloqade/pyqrack/task.py index 1502f430d..0acb6ef0d 100644 --- a/src/bloqade/pyqrack/task.py +++ b/src/bloqade/pyqrack/task.py @@ -24,14 +24,12 @@ class PyQrackSimulatorTask(AbstractSimulatorTask[Param, RetType, MemoryType]): pyqrack_interp: PyQrackInterpreter[MemoryType] def run(self) -> RetType: - return cast( - RetType, - self.pyqrack_interp.run( - self.kernel, - args=self.args, - kwargs=self.kwargs, - ), + _, ret = self.pyqrack_interp.run( + self.kernel, + *self.args, + **self.kwargs, ) + return cast(RetType, ret) @property def state(self) -> MemoryType: diff --git a/src/bloqade/qasm2/_qasm_loading.py b/src/bloqade/qasm2/_qasm_loading.py index 63ffcd5f9..57ee5815b 100644 --- a/src/bloqade/qasm2/_qasm_loading.py +++ b/src/bloqade/qasm2/_qasm_loading.py @@ -4,6 +4,7 @@ from typing import Any from kirin import ir, lowering +from kirin.types import MethodType from kirin.dialects import func from . import parse @@ -82,11 +83,10 @@ def loads( body=body, ) + body.blocks[0].args.append_from(MethodType, kernel_name + "_self") + mt = ir.Method( - mod=None, - py_func=None, sym_name=kernel_name, - arg_names=[], dialects=qasm2_lowering.dialects, code=code, ) diff --git a/src/bloqade/qasm2/dialects/expr/stmts.py b/src/bloqade/qasm2/dialects/expr/stmts.py index e2e130e68..fad08b339 100644 --- a/src/bloqade/qasm2/dialects/expr/stmts.py +++ b/src/bloqade/qasm2/dialects/expr/stmts.py @@ -87,7 +87,7 @@ def print_impl(self, printer: Printer) -> None: # QASM 2.0 arithmetic operations -PyNum = types.Union(types.Int, types.Float) +PyNum = types.TypeVar("PyNum", bound=types.Union(types.Int, types.Float)) @statement(dialect=dialect) @@ -110,7 +110,7 @@ class Sin(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the sine of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The sine of the number.""" @@ -122,7 +122,7 @@ class Cos(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the cosine of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The cosine of the number.""" @@ -134,7 +134,7 @@ class Tan(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the tangent of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The tangent of the number.""" @@ -146,7 +146,7 @@ class Exp(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the exponential of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The exponential of the number.""" @@ -158,7 +158,7 @@ class Log(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the natural log of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The natural log of the number.""" @@ -170,7 +170,7 @@ class Sqrt(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the square root of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The square root of the number.""" diff --git a/src/bloqade/qasm2/dialects/noise/fidelity.py b/src/bloqade/qasm2/dialects/noise/fidelity.py index f7ed75c40..acd17ac9c 100644 --- a/src/bloqade/qasm2/dialects/noise/fidelity.py +++ b/src/bloqade/qasm2/dialects/noise/fidelity.py @@ -31,7 +31,7 @@ def pauli_channel( # NOTE: fidelity is just the inverse probability of any noise to occur fid = (1 - p) * (1 - p_ctrl) - interp._current_gate_fidelity *= fid + interp.gate_fidelity *= fid @interp.impl(AtomLossChannel) def atom_loss( @@ -44,4 +44,4 @@ def atom_loss( addresses = interp.addr_frame.get(stmt.qargs) # NOTE: get the corresponding index and reduce survival probability accordingly for index in addresses.data: - interp._current_atom_survival_probability[index] *= 1 - stmt.prob + interp.atom_survival_probability[index] *= 1 - stmt.prob diff --git a/src/bloqade/qasm2/emit/base.py b/src/bloqade/qasm2/emit/base.py index 4f7fba1d1..4fac32172 100644 --- a/src/bloqade/qasm2/emit/base.py +++ b/src/bloqade/qasm2/emit/base.py @@ -45,11 +45,6 @@ def initialize_frame( ) -> EmitQASM2Frame[StmtType]: return EmitQASM2Frame(node, has_parent_access=has_parent_access) - def run_method( - self, method: ir.Method, args: tuple[ast.Node | None, ...] - ) -> tuple[EmitQASM2Frame[StmtType], ast.Node | None]: - return self.call(method, *args) - def emit_block(self, frame: EmitQASM2Frame, block: ir.Block) -> ast.Node | None: for stmt in block.stmts: result = self.frame_eval(frame, stmt) diff --git a/src/bloqade/qasm2/emit/target.py b/src/bloqade/qasm2/emit/target.py index a2548bba8..784034a43 100644 --- a/src/bloqade/qasm2/emit/target.py +++ b/src/bloqade/qasm2/emit/target.py @@ -106,13 +106,13 @@ def emit(self, entry: ir.Method) -> ast.MainProgram: unroll_ifs=self.unroll_ifs, ).fixpoint(entry) - # if not self.allow_global: - # # rewrite global to parallel - # GlobalToParallel(dialects=entry.dialects)(entry) + if not self.allow_global: + # rewrite global to parallel + GlobalToParallel(dialects=entry.dialects)(entry) - # if not self.allow_parallel: - # # rewrite parallel to uop - # ParallelToUOp(dialects=entry.dialects)(entry) + if not self.allow_parallel: + # rewrite parallel to uop + ParallelToUOp(dialects=entry.dialects)(entry) Py2QASM(entry.dialects)(entry) target_main = EmitQASM2Main(self.main_target).initialize() diff --git a/src/bloqade/qasm2/parse/lowering.py b/src/bloqade/qasm2/parse/lowering.py index 765d1eb3c..07d66d28c 100644 --- a/src/bloqade/qasm2/parse/lowering.py +++ b/src/bloqade/qasm2/parse/lowering.py @@ -450,7 +450,6 @@ def visit_Instruction(self, state: lowering.State[ast.Node], node: ast.Instructi func.Invoke( callee=value, inputs=tuple(params + qargs), - kwargs=tuple(), ) ) diff --git a/src/bloqade/qasm2/rewrite/uop_to_parallel.py b/src/bloqade/qasm2/rewrite/uop_to_parallel.py index b1c102ed4..5c85376bf 100644 --- a/src/bloqade/qasm2/rewrite/uop_to_parallel.py +++ b/src/bloqade/qasm2/rewrite/uop_to_parallel.py @@ -66,7 +66,7 @@ def same_id_checker(ssa1: ir.SSAValue, ssa2: ir.SSAValue): assert isinstance(hint1, lattice.Result) and isinstance( hint2, lattice.Result ) - return hint1.is_equal(hint2) + return hint1.is_structurally_equal(hint2) else: return False diff --git a/src/bloqade/qbraid/lowering.py b/src/bloqade/qbraid/lowering.py index 159580710..276e62fff 100644 --- a/src/bloqade/qbraid/lowering.py +++ b/src/bloqade/qbraid/lowering.py @@ -320,5 +320,6 @@ def lower_full_turns(self, value: float) -> ir.SSAValue: self.block_list.append(const_pi) turns = self.lower_number(2 * value) mul = qasm2.expr.Mul(const_pi.result, turns) + mul.result.type = types.Float self.block_list.append(mul) return mul.result diff --git a/src/bloqade/qbraid/schema.py b/src/bloqade/qbraid/schema.py index 54eed1c20..450d4f5a5 100644 --- a/src/bloqade/qbraid/schema.py +++ b/src/bloqade/qbraid/schema.py @@ -238,13 +238,13 @@ def decompiled_circuit(self) -> str: str: The decompiled circuit from hardware execution. """ - from bloqade.noise import native from bloqade.qasm2.emit import QASM2 from bloqade.qasm2.passes import glob, parallel + from bloqade.qasm2.rewrite.noise import remove_noise mt = self.lower_noise_model("method") - native.RemoveNoisePass(mt.dialects)(mt) + remove_noise.RemoveNoisePass(mt.dialects)(mt) parallel.ParallelToUOp(mt.dialects)(mt) glob.GlobalToUOP(mt.dialects)(mt) return QASM2(qelib1=True).emit_str(mt) diff --git a/src/bloqade/squin/analysis/schedule.py b/src/bloqade/squin/analysis/schedule.py index e99e219da..35487e083 100644 --- a/src/bloqade/squin/analysis/schedule.py +++ b/src/bloqade/squin/analysis/schedule.py @@ -185,18 +185,17 @@ def push_current_dag(self, block: ir.Block): self.stmt_dag = StmtDag() self.use_def = {} - def run_method(self, method: ir.Method, args: tuple[GateSchedule, ...]): - # NOTE: we do not support dynamic calls here, thus no need to propagate method object - return self.run_callable(method.code, (self.lattice.bottom(),) + args) + def method_self(self, method: ir.Method) -> GateSchedule: + return self.lattice.bottom() - def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): - if stmt.has_trait(ir.IsTerminator): + def eval_fallback(self, frame: ForwardFrame, node: ir.Statement): + if node.has_trait(ir.IsTerminator): assert ( - stmt.parent_block is not None + node.parent_block is not None ), "Terminator statement has no parent block" - self.push_current_dag(stmt.parent_block) + self.push_current_dag(node.parent_block) - return tuple(self.lattice.top() for _ in stmt.results) + return tuple(self.lattice.top() for _ in node.results) def _update_dag(self, stmt: ir.Statement, addr: address.Address): if isinstance(addr, address.AddressQubit): diff --git a/src/bloqade/validation/analysis/analysis.py b/src/bloqade/validation/analysis/analysis.py index 323cbd40a..19220b73d 100644 --- a/src/bloqade/validation/analysis/analysis.py +++ b/src/bloqade/validation/analysis/analysis.py @@ -28,14 +28,14 @@ class ValidationAnalysis(ForwardExtra[ValidationFrame, ErrorType], ABC): lattice = ErrorType - def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]): - return self.run_callable(method.code, (self.lattice.top(),) + args) - - def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): + def eval_fallback(self, frame: ValidationFrame, node: ir.Statement): # NOTE: default to no errors - return tuple(self.lattice.top() for _ in stmt.results) + return tuple(self.lattice.top() for _ in node.results) def initialize_frame( - self, code: ir.Statement, *, has_parent_access: bool = False + self, node: ir.Statement, *, has_parent_access: bool = False ) -> ValidationFrame: - return ValidationFrame(code, has_parent_access=has_parent_access) + return ValidationFrame(node, has_parent_access=has_parent_access) + + def method_self(self, method: ir.Method) -> ErrorType: + return self.lattice.top() diff --git a/src/bloqade/validation/kernel_validation.py b/src/bloqade/validation/kernel_validation.py index 841593520..d3f82774b 100644 --- a/src/bloqade/validation/kernel_validation.py +++ b/src/bloqade/validation/kernel_validation.py @@ -48,9 +48,23 @@ class KernelValidation: validation_analysis_cls: type[ValidationAnalysis] """The analysis that you want to run in order to validate the kernel.""" - def run(self, mt: ir.Method, **kwargs) -> None: + def run(self, mt: ir.Method, no_raise: bool = True) -> None: + """Run the kernel validation analysis and raise any errors found. + + Args: + mt (ir.Method): The method to validate + no_raise (bool): Whether or not to raise errors when running the analysis. + This is only to make sure the analysis works. Errors found during + the analysis will be raised regardless of this setting. Defaults to `True`. + + """ + validation_analysis = self.validation_analysis_cls(mt.dialects) - validation_frame, _ = validation_analysis.run_analysis(mt, **kwargs) + + if no_raise: + validation_frame, _ = validation_analysis.run_no_raise(mt) + else: + validation_frame, _ = validation_analysis.run(mt) errors = validation_frame.errors diff --git a/test/analysis/address/test_qubit_analysis.py b/test/analysis/address/test_qubit_analysis.py index 1866c9a11..dddf825a4 100644 --- a/test/analysis/address/test_qubit_analysis.py +++ b/test/analysis/address/test_qubit_analysis.py @@ -21,7 +21,7 @@ def test(): return (q1[1], q2) address_analysis = address.AddressAnalysis(test.dialects) - frame, _ = address_analysis.run_analysis(test, no_raise=False) + frame, _ = address_analysis.run(test) address_types = collect_address_types(frame, address.PartialTuple) test.print(analysis=frame.entries) @@ -116,7 +116,7 @@ def main(): return q address_analysis = address.AddressAnalysis(main.dialects) - address_analysis.run_analysis(main, no_raise=False) + address_analysis.run(main) def test_new_qubit(): @@ -125,7 +125,7 @@ def main(): return squin.qubit.new() address_analysis = address.AddressAnalysis(main.dialects) - _, result = address_analysis.run_analysis(main, no_raise=False) + _, result = address_analysis.run(main) assert result == address.AddressQubit(0) @@ -139,8 +139,9 @@ def main(n: int): return qreg address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis( - main, args=(address.ConstResult(const.Unknown()),), no_raise=False + frame, result = address_analysis.run( + main, + address.ConstResult(const.Unknown()), ) assert result == address.AddressReg(data=tuple(range(4))) @@ -155,7 +156,7 @@ def main(n: int): return qreg address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.AddressReg(data=tuple(range(4))) @@ -165,7 +166,7 @@ def main(n: int): return (0, 1) + (2, n) address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.PartialTuple( data=( @@ -183,7 +184,7 @@ def main(n: int): return (0, 1) + [2, n] # type: ignore address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.Bottom() @@ -194,7 +195,7 @@ def main(n: tuple[int, ...]): return (0, 1) + n address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.Unknown() @@ -207,7 +208,7 @@ def main(q: qubit.Qubit): return (0, q, 2, q)[1::2] address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.UnknownReg() @@ -219,7 +220,7 @@ def main(n: int): main.print() address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) main.print(analysis=frame.entries) assert ( result == address.UnknownReg() @@ -260,7 +261,7 @@ def main(): func = main analysis = address.AddressAnalysis(squin.kernel) - _, ret = analysis.run_analysis(func, no_raise=False) + _, ret = analysis.run(func) assert ret == address.AddressReg(data=tuple(range(20))) assert analysis.qubit_count == 20 diff --git a/test/analysis/fidelity/test_fidelity.py b/test/analysis/fidelity/test_fidelity.py index cbb398455..78ca8f26c 100644 --- a/test/analysis/fidelity/test_fidelity.py +++ b/test/analysis/fidelity/test_fidelity.py @@ -19,7 +19,7 @@ def main(): return q fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity == 1 assert fid_analysis.atom_survival_probability[0] == 1 - p_loss @@ -49,11 +49,10 @@ def main(): main.print() fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) expected_fidelity = (1 - 3 * p_ch) ** 2 - assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity assert math.isclose(fid_analysis.gate_fidelity, expected_fidelity) @@ -69,11 +68,10 @@ def main(): main.print() fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) expected_fidelity = 1 - 3 * p_ch - assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity assert math.isclose(fid_analysis.gate_fidelity, expected_fidelity) @@ -123,12 +121,12 @@ def main_if(): ) NoisePass(main.dialects, noise_model=model)(main) fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) model = NoiseTestModel() NoisePass(main_if.dialects, noise_model=model)(main_if) fid_if_analysis = FidelityAnalysis(main_if.dialects) - fid_if_analysis.run_analysis(main_if, no_raise=False) + fid_if_analysis.run(main_if) assert 0 < fid_if_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1 assert ( @@ -186,7 +184,7 @@ def main_for(): ) NoisePass(main.dialects, noise_model=model)(main) fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) model = NoiseTestModel() NoisePass(main_for.dialects, noise_model=model)(main_for) @@ -194,7 +192,7 @@ def main_for(): main_for.print() fid_for_analysis = FidelityAnalysis(main_for.dialects) - fid_for_analysis.run_analysis(main_for, no_raise=False) + fid_for_analysis.run(main_for) assert 0 < fid_for_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1 assert ( diff --git a/test/cirq_utils/test_cirq_to_squin.py b/test/cirq_utils/test_cirq_to_squin.py index 101cd3be8..479ba0519 100644 --- a/test/cirq_utils/test_cirq_to_squin.py +++ b/test/cirq_utils/test_cirq_to_squin.py @@ -250,9 +250,9 @@ def test_nesting_lowered_circuit(): @squin.kernel def main(): qreg = get_entangled_qubits() - qreg2 = squin.squin.qalloc(1) + qreg2 = squin.qalloc(1) entangle_qubits([qreg[1], qreg2[0]]) - return squin.qubit.measure(qreg2) + return squin.broadcast.measure(qreg2) # if you get up to here, the validation works main.print() diff --git a/test/gemini/test_logical_validation.py b/test/gemini/test_logical_validation.py index ce8e1a34d..046b2ac42 100644 --- a/test/gemini/test_logical_validation.py +++ b/test/gemini/test_logical_validation.py @@ -30,16 +30,14 @@ def main(): if m2: squin.y(q[2]) - frame, _ = GeminiLogicalValidationAnalysis(main.dialects).run_analysis( - main, no_raise=False - ) + frame, _ = GeminiLogicalValidationAnalysis(main.dialects).run_no_raise(main) main.print(analysis=frame.entries) validator = KernelValidation(GeminiLogicalValidationAnalysis) with pytest.raises(ValidationErrorGroup): - validator.run(main) + validator.run(main, no_raise=False) def test_for_loop(): @@ -104,8 +102,8 @@ def invalid(): squin.cx(q[0], q[1]) squin.u3(0.123, 0.253, 1.2, q[0]) - frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_analysis( - invalid, no_raise=False + frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_no_raise( + invalid ) invalid.print(analysis=frame.entries) diff --git a/test/pyqrack/runtime/noise/qasm2/test_loss.py b/test/pyqrack/runtime/noise/qasm2/test_loss.py index 4c1d19a14..29214c095 100644 --- a/test/pyqrack/runtime/noise/qasm2/test_loss.py +++ b/test/pyqrack/runtime/noise/qasm2/test_loss.py @@ -1,12 +1,10 @@ -from typing import Literal from unittest.mock import Mock from kirin import ir -from kirin.dialects import ilist from bloqade import qasm2 from bloqade.qasm2 import noise -from bloqade.pyqrack import PyQrackQubit, PyQrackInterpreter, reg +from bloqade.pyqrack import PyQrackInterpreter, reg from bloqade.pyqrack.base import MockMemory @@ -34,9 +32,9 @@ def test_atom_loss(c: qasm2.CReg): input = reg.CRegister(1) memory = MockMemory() - result: ilist.IList[PyQrackQubit, Literal[2]] = PyQrackInterpreter( + _, result = PyQrackInterpreter( qasm2.extended, memory=memory, rng_state=rng_state - ).run(test_atom_loss, (input,)) + ).run(test_atom_loss, input) assert result[0].state is reg.QubitState.Lost assert result[1].state is reg.QubitState.Active diff --git a/test/pyqrack/runtime/noise/qasm2/test_pauli.py b/test/pyqrack/runtime/noise/qasm2/test_pauli.py index 04541d231..2f9f2bb9c 100644 --- a/test/pyqrack/runtime/noise/qasm2/test_pauli.py +++ b/test/pyqrack/runtime/noise/qasm2/test_pauli.py @@ -11,7 +11,7 @@ def run_mock(program: ir.Method, rng_state: Mock | None = None): PyQrackInterpreter( program.dialects, memory=(memory := MockMemory()), rng_state=rng_state - ).run(program, ()) + ).run(program) assert isinstance(mock := memory.sim_reg, Mock) return mock diff --git a/test/pyqrack/runtime/test_qrack.py b/test/pyqrack/runtime/test_qrack.py index 9161cabfe..e61747974 100644 --- a/test/pyqrack/runtime/test_qrack.py +++ b/test/pyqrack/runtime/test_qrack.py @@ -14,7 +14,7 @@ def run_mock(program: ir.Method, rng_state: Mock | None = None): PyQrackInterpreter( program.dialects, memory=(memory := MockMemory()), rng_state=rng_state - ).run(program, ()) + ).run(program) assert isinstance(mock := memory.sim_reg, Mock) return mock diff --git a/test/qasm2/emit/test_qasm2_emit.py b/test/qasm2/emit/test_qasm2_emit.py index 00eaaec45..f7e776eb7 100644 --- a/test/qasm2/emit/test_qasm2_emit.py +++ b/test/qasm2/emit/test_qasm2_emit.py @@ -54,7 +54,6 @@ def glob_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_global(): @qasm2.extended @@ -85,7 +84,6 @@ def glob_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_global_allow_para(): @qasm2.extended @@ -118,7 +116,6 @@ def glob_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_para(): @qasm2.extended @@ -145,7 +142,6 @@ def para_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_para_allow_para(): @qasm2.extended @@ -201,7 +197,6 @@ def para_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_para_allow_global(): @qasm2.extended diff --git a/test/qasm2/passes/test_global_to_parallel.py b/test/qasm2/passes/test_global_to_parallel.py index b77a3a133..3b147636a 100644 --- a/test/qasm2/passes/test_global_to_parallel.py +++ b/test/qasm2/passes/test_global_to_parallel.py @@ -1,6 +1,5 @@ from typing import List -import pytest from kirin import ir, types from kirin.rewrite import Walk, Fixpoint, CommonSubexpressionElimination from kirin.dialects import py, func, ilist @@ -18,7 +17,6 @@ def as_float(value: float): return py.constant.Constant(value=value) -@pytest.mark.xfail def test_global2para_rewrite(): @qasm2.extended @@ -79,7 +77,6 @@ def main(): assert_methods(expected_method, main) -@pytest.mark.xfail def test_global2para_rewrite2(): @qasm2.extended diff --git a/test/qasm2/passes/test_global_to_uop.py b/test/qasm2/passes/test_global_to_uop.py index 9be187d9a..dafe6f2ae 100644 --- a/test/qasm2/passes/test_global_to_uop.py +++ b/test/qasm2/passes/test_global_to_uop.py @@ -1,6 +1,5 @@ from typing import List -import pytest from kirin import ir, types from kirin.rewrite import Walk, Fixpoint, CommonSubexpressionElimination from kirin.dialects import py, func @@ -18,7 +17,6 @@ def as_float(value: float): return py.constant.Constant(value=value) -@pytest.mark.xfail def test_global_rewrite(): @qasm2.extended diff --git a/test/qasm2/passes/test_heuristic_noise.py b/test/qasm2/passes/test_heuristic_noise.py index 20a4c5b0e..78879a5e4 100644 --- a/test/qasm2/passes/test_heuristic_noise.py +++ b/test/qasm2/passes/test_heuristic_noise.py @@ -1,4 +1,3 @@ -import pytest from kirin import ir, types from kirin.dialects import func, ilist from kirin.dialects.py import constant @@ -256,7 +255,6 @@ def test_parallel_cz_gate_noise(): assert_nodes(block, expected_block) -@pytest.mark.xfail def test_global_noise(): @qasm2.extended diff --git a/test/qasm2/passes/test_parallel_to_global.py b/test/qasm2/passes/test_parallel_to_global.py index 93fbac7f8..c72533b84 100644 --- a/test/qasm2/passes/test_parallel_to_global.py +++ b/test/qasm2/passes/test_parallel_to_global.py @@ -1,10 +1,7 @@ -import pytest - from bloqade import qasm2 from bloqade.qasm2.passes.parallel import ParallelToGlobal -@pytest.mark.xfail def test_basic_rewrite(): @qasm2.extended @@ -32,7 +29,6 @@ def main(): ) -@pytest.mark.xfail def test_if_rewrite(): @qasm2.extended def main(): @@ -67,7 +63,6 @@ def main(): ) -@pytest.mark.xfail def test_should_not_be_rewritten(): @qasm2.extended @@ -93,7 +88,6 @@ def main(): ) -@pytest.mark.xfail def test_multiple_registers(): @qasm2.extended def main(): @@ -126,7 +120,6 @@ def main(): ) -@pytest.mark.xfail def test_reverse_order(): @qasm2.extended def main(): diff --git a/test/qasm2/passes/test_parallel_to_uop.py b/test/qasm2/passes/test_parallel_to_uop.py index 7484542ee..c3e2c59e6 100644 --- a/test/qasm2/passes/test_parallel_to_uop.py +++ b/test/qasm2/passes/test_parallel_to_uop.py @@ -1,6 +1,5 @@ from typing import List -import pytest from kirin import ir, types from kirin.dialects import py, func @@ -17,7 +16,6 @@ def as_float(value: float): return py.constant.Constant(value=value) -@pytest.mark.xfail def test_cz_rewrite(): @qasm2.extended diff --git a/test/qasm2/passes/test_uop_to_parallel.py b/test/qasm2/passes/test_uop_to_parallel.py index 7e1f3bdc7..29231d395 100644 --- a/test/qasm2/passes/test_uop_to_parallel.py +++ b/test/qasm2/passes/test_uop_to_parallel.py @@ -1,5 +1,3 @@ -import pytest - from bloqade import qasm2 from bloqade.qasm2 import glob from bloqade.analysis import address @@ -7,7 +5,6 @@ from bloqade.qasm2.rewrite import SimpleOptimalMergePolicy -@pytest.mark.xfail def test_one(): @qasm2.gate @@ -50,7 +47,6 @@ def test(): ) -@pytest.mark.xfail def test_two(): @qasm2.extended @@ -89,7 +85,6 @@ def test(): _, _ = address.AddressAnalysis(test.dialects).run(test) -@pytest.mark.xfail def test_three(): @qasm2.extended diff --git a/test/qasm2/test_count.py b/test/qasm2/test_count.py index df5cef317..32aed54bd 100644 --- a/test/qasm2/test_count.py +++ b/test/qasm2/test_count.py @@ -1,4 +1,3 @@ -import pytest from kirin import passes from kirin.dialects import py, ilist @@ -16,7 +15,6 @@ fold = passes.Fold(qasm2.main.add(py.tuple).add(ilist)) -@pytest.mark.xfail def test_fixed_count(): @qasm2.main def fixed_count(): @@ -36,7 +34,6 @@ def fixed_count(): assert address.qubit_count == 7 -@pytest.mark.xfail def test_multiple_return_only_reg(): @qasm2.main.add(py.tuple) @@ -54,7 +51,6 @@ def tuple_count(): assert isinstance(ret.data[1], AddressReg) and ret.data[1].data == range(3, 7) -@pytest.mark.xfail def test_dynamic_address(): @qasm2.main def dynamic_address(): @@ -64,14 +60,16 @@ def dynamic_address(): qasm2.measure(ra[0], ca[0]) qasm2.measure(rb[1], ca[1]) if ca[0] == ca[1]: - return ra + ret = ra else: - return rb + ret = rb + + return ret # dynamic_address.code.print() dynamic_address.print() fold(dynamic_address) - frame, result = address.run_analysis(dynamic_address) + frame, result = address.run(dynamic_address) dynamic_address.print(analysis=frame.entries) assert isinstance(result, Unknown) @@ -92,7 +90,6 @@ def dynamic_address(): # assert isinstance(result, ConstResult) -@pytest.mark.xfail def test_multi_return(): @qasm2.main.add(py.tuple) def multi_return_cnt(): @@ -110,7 +107,6 @@ def multi_return_cnt(): assert isinstance(result.data[2], AddressReg) -@pytest.mark.xfail def test_list(): @qasm2.main.add(ilist) def list_count_analy(): @@ -120,12 +116,11 @@ def list_count_analy(): return f list_count_analy.code.print() - _, ret = address.run_analysis(list_count_analy) + _, ret = address.run(list_count_analy) assert ret == AddressReg(data=(0, 1, 3)) -@pytest.mark.xfail def test_tuple_qubits(): @qasm2.main.add(py.tuple) def list_count_analy2(): @@ -136,7 +131,7 @@ def list_count_analy2(): return f list_count_analy2.code.print() - _, ret = address.run_analysis(list_count_analy2) + _, ret = address.run(list_count_analy2) assert isinstance(ret, PartialTuple) assert isinstance(ret.data[0], AddressQubit) and ret.data[0].data == 0 assert isinstance(ret.data[1], AddressQubit) and ret.data[1].data == 1 @@ -166,7 +161,6 @@ def list_count_analy2(): # assert isinstance(result.data[5], AddressQubit) and result.data[5].data == 6 -@pytest.mark.xfail def test_alias(): @qasm2.main diff --git a/test/qasm2/test_lowering.py b/test/qasm2/test_lowering.py index 6eee1509b..617d71541 100644 --- a/test/qasm2/test_lowering.py +++ b/test/qasm2/test_lowering.py @@ -3,7 +3,6 @@ import tempfile import textwrap -import pytest from kirin import ir, types from kirin.dialects import func @@ -26,14 +25,12 @@ ) -@pytest.mark.xfail def test_run_lowering(): ast = qasm2.parse.loads(lines) code = QASM2(qasm2.main).run(ast) code.print() -@pytest.mark.xfail def test_loadfile(): with tempfile.TemporaryDirectory() as tmp_dir: @@ -44,7 +41,6 @@ def test_loadfile(): qasm2.loadfile(file) -@pytest.mark.xfail def test_negative_lowering(): mwe = """ @@ -84,7 +80,6 @@ def test_negative_lowering(): assert entry.code.is_structurally_equal(code) -@pytest.mark.xfail def test_gate(): qasm2_prog = textwrap.dedent( """ @@ -113,7 +108,6 @@ def test_gate(): assert math.isclose(abs(ket[3]) ** 2, 0.5, abs_tol=1e-6) -@pytest.mark.xfail def test_gate_with_params(): qasm2_prog = textwrap.dedent( """ @@ -144,7 +138,6 @@ def test_gate_with_params(): assert math.isclose(abs(ket[3]) ** 2, 0.5, abs_tol=1e-6) -@pytest.mark.xfail def test_if_lowering(): qasm2_prog = textwrap.dedent( diff --git a/test/qasm2/test_native.py b/test/qasm2/test_native.py index 15e6ebcb9..fdfaf64d5 100644 --- a/test/qasm2/test_native.py +++ b/test/qasm2/test_native.py @@ -3,7 +3,6 @@ import cirq import numpy as np -import pytest import cirq.testing import cirq.contrib.qasm_import as qasm_import import cirq.circuits.qasm_output as qasm_output @@ -158,7 +157,6 @@ def kernel(): assert new_qasm2.count("\n") > prog.count("\n") -@pytest.mark.xfail def test_ccx_rewrite(): @qasm2.extended diff --git a/test/qbraid/test_lowering.py b/test/qbraid/test_lowering.py index 323972a89..cabca44ed 100644 --- a/test/qbraid/test_lowering.py +++ b/test/qbraid/test_lowering.py @@ -56,11 +56,8 @@ def run_assert(noise_model: schema.NoiseModel, expected_stmts: List[ir.Statement ) expected_mt = ir.Method( - mod=None, - py_func=None, dialects=lowering.qbraid_noise, sym_name="test", - arg_names=[], code=expected_func_stmt, ) @@ -242,7 +239,10 @@ def test_lowering_global_w(): (lam_num := as_float(2 * -(0.5 + phi_val))), (lam := qasm2.expr.Mul(pi_lam.result, lam_num.result)), parallel.UGate( - theta=theta.result, phi=phi.result, lam=lam.result, qargs=qargs.result + theta=ir.ResultValue(theta, 0, type=types.Float), + phi=ir.ResultValue(phi, 0, type=types.Float), + lam=ir.ResultValue(lam, 0, type=types.Float), + qargs=qargs.result, ), func.Return(creg.result), ] @@ -304,7 +304,10 @@ def test_lowering_local_w(): (lam_num := as_float(2 * -(0.5 + phi_val))), (lam := qasm2.expr.Mul(pi_lam.result, lam_num.result)), parallel.UGate( - qargs=qargs.result, theta=theta.result, phi=phi.result, lam=lam.result + qargs=qargs.result, + theta=ir.ResultValue(theta, 0, type=types.Float), + phi=ir.ResultValue(phi, 0, type=types.Float), + lam=ir.ResultValue(lam, 0, type=types.Float), ), func.Return(creg.result), ] @@ -348,7 +351,9 @@ def test_lowering_global_rz(): (theta_pi := qasm2.expr.ConstPI()), (theta_num := as_float(2 * phi_val)), (theta := qasm2.expr.Mul(theta_pi.result, theta_num.result)), - parallel.RZ(theta=theta.result, qargs=qargs.result), + parallel.RZ( + theta=ir.ResultValue(theta, 0, type=types.Float), qargs=qargs.result + ), func.Return(creg.result), ] @@ -401,7 +406,9 @@ def test_lowering_local_rz(): (theta_pi := qasm2.expr.ConstPI()), (theta_num := as_float(2 * phi_val)), (theta := qasm2.expr.Mul(theta_pi.result, theta_num.result)), - parallel.RZ(theta=theta.result, qargs=qargs.result), + parallel.RZ( + theta=ir.ResultValue(theta, 0, type=types.Float), qargs=qargs.result + ), func.Return(creg.result), ] diff --git a/test/squin/test_qubit.py b/test/squin/test_qubit.py index bc16f4cc3..001b92502 100644 --- a/test/squin/test_qubit.py +++ b/test/squin/test_qubit.py @@ -19,7 +19,7 @@ def main(): main.print() assert main.return_type.is_subseteq(types.Int) - @squin.kernel + @squin.kernel(fold=False) def main2(): q = squin.qalloc(2) @@ -34,10 +34,10 @@ def main2(): if m1_id != 0: # do something that errors - squin.x(q[4]) + q[0] + 1 if m2_id != 1: - squin.x(q[4]) + q[0] + 1 return squin.broadcast.measure(q) diff --git a/test/stim/passes/stim_reference_programs/debug/debug.stim b/test/stim/passes/stim_reference_programs/debug/debug.stim index 7479e928f..d4d0923cd 100644 --- a/test/stim/passes/stim_reference_programs/debug/debug.stim +++ b/test/stim/passes/stim_reference_programs/debug/debug.stim @@ -1,2 +1 @@ - # debug message diff --git a/test/stim/passes/stim_reference_programs/qubit/for_loop_nontrivial_index.stim b/test/stim/passes/stim_reference_programs/qubit/for_loop_nontrivial_index.stim index 5081b8918..7dc34014a 100644 --- a/test/stim/passes/stim_reference_programs/qubit/for_loop_nontrivial_index.stim +++ b/test/stim/passes/stim_reference_programs/qubit/for_loop_nontrivial_index.stim @@ -1,4 +1,3 @@ - H 0 CX 0 1 CX 1 2 diff --git a/test/stim/passes/stim_reference_programs/qubit/nested_for_loop.stim b/test/stim/passes/stim_reference_programs/qubit/nested_for_loop.stim index 26adce106..f3c3c462b 100644 --- a/test/stim/passes/stim_reference_programs/qubit/nested_for_loop.stim +++ b/test/stim/passes/stim_reference_programs/qubit/nested_for_loop.stim @@ -1,4 +1,3 @@ - H 0 CX 0 2 CX 1 2 diff --git a/test/stim/passes/stim_reference_programs/qubit/nested_list.stim b/test/stim/passes/stim_reference_programs/qubit/nested_list.stim index 37920f352..0c58c050a 100644 --- a/test/stim/passes/stim_reference_programs/qubit/nested_list.stim +++ b/test/stim/passes/stim_reference_programs/qubit/nested_list.stim @@ -1,3 +1,2 @@ - H 0 H 2 diff --git a/test/stim/passes/stim_reference_programs/qubit/non_pure_loop_iterator.stim b/test/stim/passes/stim_reference_programs/qubit/non_pure_loop_iterator.stim index 9088e2756..ccb35ef0e 100644 --- a/test/stim/passes/stim_reference_programs/qubit/non_pure_loop_iterator.stim +++ b/test/stim/passes/stim_reference_programs/qubit/non_pure_loop_iterator.stim @@ -1,4 +1,3 @@ - MZ(0.00000000) 0 1 2 3 4 X 0 X 1 diff --git a/test/stim/passes/stim_reference_programs/qubit/pick_if_else.stim b/test/stim/passes/stim_reference_programs/qubit/pick_if_else.stim index 05e9c2796..f0285a3d5 100644 --- a/test/stim/passes/stim_reference_programs/qubit/pick_if_else.stim +++ b/test/stim/passes/stim_reference_programs/qubit/pick_if_else.stim @@ -1,2 +1 @@ - H 1 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit.stim b/test/stim/passes/stim_reference_programs/qubit/qubit.stim index 17873714b..51cc860cb 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit.stim @@ -1,4 +1,3 @@ - H 0 1 X 0 CX 0 1 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit_broadcast.stim b/test/stim/passes/stim_reference_programs/qubit/qubit_broadcast.stim index 708cb2a0c..e033406d6 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit_broadcast.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit_broadcast.stim @@ -1,3 +1,2 @@ - H 0 1 2 3 MZ(0.00000000) 0 1 2 3 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit_loss.stim b/test/stim/passes/stim_reference_programs/qubit/qubit_loss.stim index 4dee3da23..62b6bf1b3 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit_loss.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit_loss.stim @@ -1,4 +1,3 @@ - H 0 1 2 3 4 I_ERROR[loss](0.10000000) 3 I_ERROR[loss](0.05000000) 0 1 2 3 4 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit_reset.stim b/test/stim/passes/stim_reference_programs/qubit/qubit_reset.stim index 958e0cfea..bcecaca7e 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit_reset.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit_reset.stim @@ -1,3 +1,2 @@ - RZ 0 MZ(0.00000000) 0 diff --git a/test/stim/passes/stim_reference_programs/qubit/rep_code.stim b/test/stim/passes/stim_reference_programs/qubit/rep_code.stim index 9105cf433..171dd8726 100644 --- a/test/stim/passes/stim_reference_programs/qubit/rep_code.stim +++ b/test/stim/passes/stim_reference_programs/qubit/rep_code.stim @@ -1,4 +1,3 @@ - RZ 0 1 2 3 4 CX 0 1 2 3 CX 2 1 4 3 diff --git a/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim b/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim index 4764563d0..cfc100b0d 100644 --- a/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim +++ b/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim @@ -1,4 +1,3 @@ - Z 0 SQRT_X_DAG 0 SQRT_X_DAG 0 diff --git a/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim b/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim index 6eb044c65..17a51d338 100644 --- a/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim +++ b/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim @@ -1,3 +1,2 @@ - H 0 MZ(0.00000000) 0 diff --git a/test/stim/passes/test_squin_noise_to_stim.py b/test/stim/passes/test_squin_noise_to_stim.py index dc3346d5a..04df28378 100644 --- a/test/stim/passes/test_squin_noise_to_stim.py +++ b/test/stim/passes/test_squin_noise_to_stim.py @@ -252,10 +252,11 @@ def test(): SquinToStimPass(test.dialects)(test) - emit = EmitStimMain(correlation_identifier_offset=10) + buf = io.StringIO() + emit = EmitStimMain(stim.main, correlation_identifier_offset=10, io=buf) emit.initialize() - emit.run(mt=test, args=()) - stim_str = emit.get_output().strip() + emit.run(test) + stim_str = buf.getvalue().strip() assert stim_str == "I_ERROR[correlated_loss:10](0.10000000) 0 1 2 3"