Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ requires-python = ">=3.10"
dependencies = [
"numpy>=1.22.0",
"scipy>=1.13.1",
"kirin-toolchain~=0.16.0",
"kirin-toolchain~=0.17.0",
"rich>=13.9.4",
"pydantic>=1.3.0,<2.11.0",
"pandas>=2.2.3",
Expand Down
10 changes: 5 additions & 5 deletions src/bloqade/noise/native/rewrite.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
from dataclasses import dataclass

from kirin import ir
from kirin.rewrite import abc, dce, walk, result, fixpoint
from kirin.rewrite import abc, dce, walk, fixpoint
from kirin.passes.abc import Pass

from .stmts import PauliChannel, CZPauliChannel, AtomLossChannel
from ._dialect import dialect


class RemoveNoiseRewrite(abc.RewriteRule):
def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
if isinstance(node, (AtomLossChannel, PauliChannel, CZPauliChannel)):
node.delete()
return result.RewriteResult(has_done_something=True)
return abc.RewriteResult(has_done_something=True)

return result.RewriteResult()
return abc.RewriteResult()


@dataclass
class RemoveNoisePass(Pass):
name = "remove-noise"

def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
delete_walk = walk.Walk(RemoveNoiseRewrite())
dce_walk = fixpoint.Fixpoint(walk.Walk(dce.DeadCodeElimination()))

Expand Down
4 changes: 2 additions & 2 deletions src/bloqade/pyqrack/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def run(
"""
fold = Fold(mt.dialects)
fold(mt)
return self._get_interp(mt).run(mt, args, kwargs).expect()
return self._get_interp(mt).run(mt, args, kwargs)

def multi_run(
self,
Expand All @@ -107,6 +107,6 @@ def multi_run(
interpreter = self._get_interp(mt)
batched_results = []
for _ in range(_shots):
batched_results.append(interpreter.run(mt, args, kwargs).expect())
batched_results.append(interpreter.run(mt, args, kwargs))

return batched_results
26 changes: 9 additions & 17 deletions src/bloqade/qasm2/_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,33 +61,25 @@ def reset(qarg: Qubit) -> None:


@overload
def measure(qarg: Qubit, cbit: Bit) -> None:
"""
Measure the qubit `qarg` and store the result in the classical bit `cbit`.

Args:
qarg: The qubit to measure.
cbit: The classical bit to store the result in.
"""
...
def measure(qreg: QReg, creg: CReg) -> None: ...


@overload
def measure(qarg: QReg, carg: CReg) -> None:
def measure(qarg: Qubit, cbit: Bit) -> None: ...


@wraps(core.Measure)
def measure(qarg, cbit) -> None:
"""
Measure each qubit in the quantum register `qarg` and store the result in the classical register `carg`.
Measure the qubit `qarg` and store the result in the classical bit `cbit`.

Args:
qarg: The quantum register to measure.
carg: The classical bit to store the result in.
qarg: The qubit to measure.
cbit: The classical bit to store the result in.
"""
...


@wraps(core.Measure)
def measure(qarg, carg) -> None: ...


@wraps(uop.CX)
def cx(ctrl: Qubit, qarg: Qubit) -> None:
"""
Expand Down
7 changes: 4 additions & 3 deletions src/bloqade/qasm2/dialects/expr/_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@
self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.GateFunction
):

cparams, qparams = [], []
args, cparams, qparams = [], [], []

Check warning on line 22 in src/bloqade/qasm2/dialects/expr/_emit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/dialects/expr/_emit.py#L22

Added line #L22 was not covered by tests
for arg in stmt.body.blocks[0].args[1:]:
name = frame.get(arg)
name = frame.get_typed(arg, ast.Name)
args.append(name)

Check warning on line 25 in src/bloqade/qasm2/dialects/expr/_emit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/dialects/expr/_emit.py#L24-L25

Added lines #L24 - L25 were not covered by tests
if not isinstance(name, ast.Name):
raise EmitError("expected ast.Name")
if arg.type.is_subseteq(QubitType):
qparams.append(name.id)
else:
cparams.append(name.id)
emit.run_ssacfg_region(frame, stmt.body)
emit.run_ssacfg_region(frame, stmt.body, args)

Check warning on line 32 in src/bloqade/qasm2/dialects/expr/_emit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/dialects/expr/_emit.py#L32

Added line #L32 was not covered by tests
emit.output = ast.Gate(
name=stmt.sym_name,
cparams=cparams,
Expand Down
6 changes: 4 additions & 2 deletions src/bloqade/qasm2/emit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ def initialize(self) -> Self:
)
return self

def new_frame(self, code: ir.Statement) -> EmitQASM2Frame:
return EmitQASM2Frame.from_func_like(code)
def initialize_frame(
self, code: ir.Statement, *, has_parent_access: bool = False
) -> EmitQASM2Frame[StmtType]:
return EmitQASM2Frame(code, has_parent_access=has_parent_access)

def run_method(
self, method: ir.Method, args: tuple[ast.Node | None, ...]
Expand Down
8 changes: 6 additions & 2 deletions src/bloqade/qasm2/emit/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,16 @@
def emit_func(
self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: func.Function
):
emit.run_ssacfg_region(frame, stmt.body)
args_ssa = stmt.args
print(stmt.args)
emit.run_ssacfg_region(frame, stmt.body, frame.get_values(args_ssa))

Check warning on line 96 in src/bloqade/qasm2/emit/gate.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/emit/gate.py#L94-L96

Added lines #L94 - L96 were not covered by tests

cparams, qparams = [], []
for arg in stmt.args:
for arg in args_ssa:

Check warning on line 99 in src/bloqade/qasm2/emit/gate.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/emit/gate.py#L99

Added line #L99 was not covered by tests
if arg.type.is_subseteq(QubitType):
qparams.append(frame.get(arg))
else:
cparams.append(frame.get(arg))

emit.output = ast.Gate(stmt.sym_name, cparams, qparams, frame.body)
return ()
34 changes: 19 additions & 15 deletions src/bloqade/qasm2/emit/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
):
from bloqade.qasm2.dialects import glob, noise, parallel

emit.run_ssacfg_region(frame, stmt.body)
emit.run_ssacfg_region(frame, stmt.body, ())
if emit.dialects.data.intersection(
(parallel.dialect, glob.dialect, noise.dialect)
):
Expand All @@ -51,12 +51,14 @@
self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: cf.ConditionalBranch
):
cond = emit.assert_node(ast.Cmp, frame.get(stmt.cond))
body_frame = emit.new_frame(stmt)
body_frame.entries.update(frame.entries)
body_frame.set_values(
stmt.then_successor.args, frame.get_values(stmt.then_arguments)
)
emit.emit_block(body_frame, stmt.then_successor)

with emit.new_frame(stmt) as body_frame:
body_frame.entries.update(frame.entries)
body_frame.set_values(

Check warning on line 57 in src/bloqade/qasm2/emit/main.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/emit/main.py#L55-L57

Added lines #L55 - L57 were not covered by tests
stmt.then_successor.args, frame.get_values(stmt.then_arguments)
)
emit.emit_block(body_frame, stmt.then_successor)

Check warning on line 60 in src/bloqade/qasm2/emit/main.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/emit/main.py#L60

Added line #L60 was not covered by tests

frame.body.append(
ast.IfStmt(
cond,
Expand Down Expand Up @@ -91,15 +93,17 @@
)

cond = emit.assert_node(ast.Cmp, frame.get(stmt.cond))
then_frame = emit.new_frame(stmt)
then_frame.entries.update(frame.entries)
emit.emit_block(then_frame, stmt.then_body.blocks[0])
frame.body.append(
ast.IfStmt(
cond,
body=then_frame.body, # type: ignore

with emit.new_frame(stmt) as then_frame:
then_frame.entries.update(frame.entries)
emit.emit_block(then_frame, stmt.then_body.blocks[0])
frame.body.append(
ast.IfStmt(
cond,
body=then_frame.body, # type: ignore
)
)
)

term = stmt.then_body.blocks[0].last_stmt
if isinstance(term, scf.Yield):
return then_frame.get_values(term.values)
Expand Down
8 changes: 2 additions & 6 deletions src/bloqade/qasm2/emit/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@

Py2QASM(entry.dialects)(entry)
target_main = EmitQASM2Main(self.main_target)
target_main.run(
entry, tuple(ast.Name(name) for name in entry.arg_names[1:])
).expect()
target_main.run(entry, tuple(ast.Name(name) for name in entry.arg_names[1:]))

main_program = target_main.output
assert main_program is not None, f"failed to emit {entry.sym_name}"
Expand Down Expand Up @@ -133,9 +131,7 @@

Py2QASM(fn.dialects)(fn)

target_gate.run(
fn, tuple(ast.Name(name) for name in fn.arg_names[1:])
).expect()
target_gate.run(fn, tuple(ast.Name(name) for name in fn.arg_names[1:]))

Check warning on line 134 in src/bloqade/qasm2/emit/target.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/emit/target.py#L134

Added line #L134 was not covered by tests
assert target_gate.output is not None, f"failed to emit {fn.sym_name}"
extra.append(target_gate.output)

Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/passes/fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from kirin.analysis import const
from kirin.dialects import scf, ilist
from kirin.ir.method import Method
from kirin.rewrite.result import RewriteResult
from kirin.rewrite.abc import RewriteResult

from bloqade.qasm2.dialects import expr

Expand Down
6 changes: 3 additions & 3 deletions src/bloqade/qasm2/passes/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from kirin import ir
from kirin.rewrite import cse, dce, walk, result
from kirin.rewrite import abc, cse, dce, walk
from kirin.passes.abc import Pass
from kirin.passes.fold import Fold
from kirin.rewrite.fixpoint import Fixpoint
Expand Down Expand Up @@ -54,7 +54,7 @@ def generate_rule(self, mt: ir.Method) -> GlobalToUOpRule:
frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
return GlobalToUOpRule(frame.entries)

def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
rewriter = walk.Walk(self.generate_rule(mt))
result = rewriter.rewrite(mt.code)

Expand Down Expand Up @@ -106,7 +106,7 @@ def generate_rule(self, mt: ir.Method) -> GlobalToParallelRule:
frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
return GlobalToParallelRule(frame.entries)

def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
rewriter = walk.Walk(self.generate_rule(mt))
result = rewriter.rewrite(mt.code)

Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/passes/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
DeadCodeElimination,
CommonSubexpressionElimination,
)
from kirin.rewrite.result import RewriteResult
from kirin.rewrite.abc import RewriteResult

from bloqade.noise import native
from bloqade.analysis import address
Expand Down
6 changes: 3 additions & 3 deletions src/bloqade/qasm2/passes/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ConstantFold,
DeadCodeElimination,
CommonSubexpressionElimination,
result,
abc,
)
from kirin.analysis import const

Expand Down Expand Up @@ -84,7 +84,7 @@ def generate_rule(self, mt: ir.Method) -> ParallelToUOpRule:

return ParallelToUOpRule(id_map=id_map, address_analysis=frame.entries)

def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
result = Walk(self.generate_rule(mt)).rewrite(mt.code)
rule = Chain(
ConstantFold(),
Expand Down Expand Up @@ -140,7 +140,7 @@ def test():
def __post_init__(self):
self.constprop = const.Propagate(self.dialects)

def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
result = Walk(RaiseRegisterRule()).rewrite(mt.code)

# do not run the parallelization because registers are not at the top
Expand Down
3 changes: 1 addition & 2 deletions src/bloqade/qasm2/passes/py2qasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from kirin.passes import Pass
from kirin.rewrite import Walk, Fixpoint
from kirin.dialects import py, math
from kirin.rewrite.abc import RewriteRule
from kirin.rewrite.result import RewriteResult
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade.qasm2.dialects import core, expr

Expand Down
3 changes: 1 addition & 2 deletions src/bloqade/qasm2/passes/qasm2py.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from kirin.passes import Pass
from kirin.rewrite import Walk, Fixpoint
from kirin.dialects import py, math
from kirin.rewrite.abc import RewriteRule
from kirin.rewrite.result import RewriteResult
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade.qasm2.dialects import core, expr

Expand Down
12 changes: 6 additions & 6 deletions src/bloqade/qasm2/rewrite/desugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,27 @@

from kirin import ir
from kirin.passes import Pass
from kirin.rewrite import abc, walk, result
from kirin.rewrite import abc, walk
from kirin.dialects import py

from bloqade.qasm2.dialects import core


class IndexingDesugarRule(abc.RewriteRule):
def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
if isinstance(node, py.indexing.GetItem):
if node.obj.type.is_subseteq(core.QRegType):
node.replace_by(core.QRegGet(reg=node.obj, idx=node.index))
return result.RewriteResult(has_done_something=True)
return abc.RewriteResult(has_done_something=True)
elif node.obj.type.is_subseteq(core.CRegType):
node.replace_by(core.CRegGet(reg=node.obj, idx=node.index))
return result.RewriteResult(has_done_something=True)
return abc.RewriteResult(has_done_something=True)

return result.RewriteResult()
return abc.RewriteResult()


@dataclass
class IndexingDesugarPass(Pass):
def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:

return walk.Walk(IndexingDesugarRule()).rewrite(mt.code)
Loading